[SPARK-18214][SQL] Simplify RuntimeReplaceable type coercion #15723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 5 commits
@@ -528,8 +528,6 @@ object TypeCoercion {
NaNvl(l, Cast(r, DoubleType))
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
NaNvl(Cast(l, DoubleType), r)

case e: RuntimeReplaceable => e.replaceForTypeCoercion()
}
}

@@ -186,7 +186,7 @@ abstract class Expression extends TreeNode[Expression] {
*/
def prettyName: String = nodeName.toLowerCase

protected def flatArguments = productIterator.flatMap {
protected def flatArguments: Iterator[Any] = productIterator.flatMap {
case t: Traversable[_] => t
case single => single :: Nil
}
@@ -229,26 +229,16 @@ trait Unevaluable extends Expression {
* An expression that gets replaced at runtime (currently by the optimizer) into a different
* expression for evaluation. This is mainly used to provide compatibility with other databases.
* For example, we use this to support "nvl" by replacing it with "coalesce".
*
* A RuntimeReplaceable should have the original parameters along with a "child" expression in the
* case class constructor, and define a normal constructor that accepts only the original
* parameters. For an example, see [[Nvl]]. To make sure the explain plan and expression SQL
* works correctly, the implementation should also override flatArguments method and sql method.
*/
trait RuntimeReplaceable extends Unevaluable {
/**
* Method for concrete implementations to override that specifies how to construct the expression
* that should replace the current one.
*/
def replaceForEvaluation(): Expression

/**
* Method for concrete implementations to override that specifies how to coerce the input types.
*/
def replaceForTypeCoercion(): Expression

/** The expression that should be used during evaluation. */
lazy val replaced: Expression = replaceForEvaluation()

override def nullable: Boolean = replaced.nullable
override def foldable: Boolean = replaced.foldable
override def dataType: DataType = replaced.dataType
override def checkInputDataTypes(): TypeCheckResult = replaced.checkInputDataTypes()
trait RuntimeReplaceable extends UnaryExpression with Unevaluable {
override def nullable: Boolean = child.nullable
override def foldable: Boolean = child.foldable
override def dataType: DataType = child.dataType
}
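
(Editor's note, not part of the patch.) The doc comment above spells out the new contract: keep the original arguments plus a precomputed "child" expression in the case class constructor, add a secondary constructor that builds that child, and override flatArguments and sql so EXPLAIN output and generated SQL show the original call. A minimal sketch of a hypothetical expression following this pattern (the name NullIfZero and its semantics are invented purely for illustration; concrete examples from the patch itself, such as Nvl, appear further down in this diff):

// Hypothetical example only -- nullifzero(e) is not a real Spark function.
// The original argument sits next to a precomputed `child` that holds the
// replacement; the secondary constructor builds that child, and
// flatArguments/sql hide the rewrite from EXPLAIN and SQL generation.
case class NullIfZero(expr: Expression, child: Expression) extends RuntimeReplaceable {

  def this(expr: Expression) = {
    this(expr, If(EqualTo(expr, Literal(0)), Literal.create(null, expr.dataType), expr))
  }

  override def flatArguments: Iterator[Any] = Iterator(expr)
  override def sql: String = s"$prettyName(${expr.sql})"
}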


@@ -488,8 +488,6 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
}""")
}
}

override def prettyName: String = "unix_time"
}

/**
@@ -89,78 +89,53 @@ case class Coalesce(children: Seq[Expression]) extends Expression {


@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.")
case class IfNull(left: Expression, right: Expression) extends RuntimeReplaceable {
override def children: Seq[Expression] = Seq(left, right)

override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right))
case class IfNull(left: Expression, right: Expression, child: Expression)
extends RuntimeReplaceable {

override def replaceForTypeCoercion(): Expression = {
if (left.dataType != right.dataType) {
TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
copy(left = Cast(left, dtype), right = Cast(right, dtype))
}.getOrElse(this)
} else {
this
}
def this(left: Expression, right: Expression) = {
this(left, right, Coalesce(Seq(left, right)))
}

override def flatArguments: Iterator[Any] = Iterator(left, right)
override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}


@ExpressionDescription(usage = "_FUNC_(a,b) - Returns null if a equals to b, or a otherwise.")
case class NullIf(left: Expression, right: Expression) extends RuntimeReplaceable {
override def children: Seq[Expression] = Seq(left, right)
case class NullIf(left: Expression, right: Expression, child: Expression)
extends RuntimeReplaceable {

override def replaceForEvaluation(): Expression = {
If(EqualTo(left, right), Literal.create(null, left.dataType), left)
def this(left: Expression, right: Expression) = {
this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left))
}

override def replaceForTypeCoercion(): Expression = {
if (left.dataType != right.dataType) {
TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
copy(left = Cast(left, dtype), right = Cast(right, dtype))
}.getOrElse(this)
} else {
this
}
}
override def flatArguments: Iterator[Any] = Iterator(left, right)
override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}


@ExpressionDescription(usage = "_FUNC_(a,b) - Returns b if a is null, or a otherwise.")
case class Nvl(left: Expression, right: Expression) extends RuntimeReplaceable {
override def children: Seq[Expression] = Seq(left, right)
case class Nvl(left: Expression, right: Expression, child: Expression) extends RuntimeReplaceable {

override def replaceForEvaluation(): Expression = Coalesce(Seq(left, right))

override def replaceForTypeCoercion(): Expression = {
if (left.dataType != right.dataType) {
TypeCoercion.findTightestCommonTypeToString(left.dataType, right.dataType).map { dtype =>
copy(left = Cast(left, dtype), right = Cast(right, dtype))
}.getOrElse(this)
} else {
this
}
def this(left: Expression, right: Expression) = {
this(left, right, Coalesce(Seq(left, right)))
}

override def flatArguments: Iterator[Any] = Iterator(left, right)
override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"
}


@ExpressionDescription(usage = "_FUNC_(a,b,c) - Returns b if a is not null, or c otherwise.")
case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression)
case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression, child: Expression)
extends RuntimeReplaceable {

override def replaceForEvaluation(): Expression = If(IsNotNull(expr1), expr2, expr3)

override def children: Seq[Expression] = Seq(expr1, expr2, expr3)

override def replaceForTypeCoercion(): Expression = {
if (expr2.dataType != expr3.dataType) {
TypeCoercion.findTightestCommonTypeOfTwo(expr2.dataType, expr3.dataType).map { dtype =>
copy(expr2 = Cast(expr2, dtype), expr3 = Cast(expr3, dtype))
}.getOrElse(this)
} else {
this
}
def this(expr1: Expression, expr2: Expression, expr3: Expression) = {
this(expr1, expr2, expr3, If(IsNotNull(expr1), expr2, expr3))
}

override def flatArguments: Iterator[Any] = Iterator(expr1, expr2, expr3)
override def sql: String = s"$prettyName(${expr1.sql}, ${expr2.sql}, ${expr3.sql})"
}


@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
*/
object ReplaceExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e: RuntimeReplaceable => e.replaced
case e: RuntimeReplaceable => e.child
}
}
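
(Editor's note.) With the trait now exposing its replacement as child, the optimizer rule reduces to a one-line transform. A sketch of its effect on a concrete expression tree, using the Nvl secondary constructor added in this patch (illustration only, not code from the patch):

// Illustration: what `case e: RuntimeReplaceable => e.child` does to nvl(a, 'x').
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StringType

val a = AttributeReference("a", StringType)()
val nvl = new Nvl(a, Literal("x"))         // carries child = Coalesce(Seq(a, Literal("x")))
val afterOptimizer: Expression = nvl.child // the optimized plan keeps only the Coalesce

This matches the explain output in the new golden file below, where nvl(id, 'x') shows up as coalesce(cast(id as string), x) in the optimized and physical plans.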

@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types._

class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -86,18 +88,23 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("SPARK-16602 Nvl should support numeric-string cases") {
def analyze(expr: Expression): Expression = {
val relation = LocalRelation()
SimpleAnalyzer.execute(Project(Seq(Alias(expr, "c")()), relation)).expressions.head
}

val intLit = Literal.create(1, IntegerType)
val doubleLit = Literal.create(2.2, DoubleType)
val stringLit = Literal.create("c", StringType)
val nullLit = Literal.create(null, NullType)

assert(Nvl(intLit, doubleLit).replaceForTypeCoercion().dataType == DoubleType)
assert(Nvl(intLit, stringLit).replaceForTypeCoercion().dataType == StringType)
assert(Nvl(stringLit, doubleLit).replaceForTypeCoercion().dataType == StringType)
assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType)
assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType)
assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType)

assert(Nvl(nullLit, intLit).replaceForTypeCoercion().dataType == IntegerType)
assert(Nvl(doubleLit, nullLit).replaceForTypeCoercion().dataType == DoubleType)
assert(Nvl(nullLit, stringLit).replaceForTypeCoercion().dataType == StringType)
assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType)
assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType)
assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType)
}
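
(Editor's note.) The analyze helper above would cover the other rewritten expressions just as well; for instance, these hypothetical additions inside the test (not part of this suite, but consistent with the type-coercion queries in the new golden file below) should also pass, since IfNull wraps the same Coalesce child as Nvl:

// Hypothetical extra assertions, reusing intLit/doubleLit/nullLit and analyze above.
assert(analyze(new IfNull(intLit, doubleLit)).dataType == DoubleType)
assert(analyze(new IfNull(nullLit, doubleLit)).dataType == DoubleType)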

test("AtLeastNNonNulls") {
@@ -0,0 +1,25 @@
-- A test suite for functions added for compatibility with other databases such as Oracle, MSSQL.
-- These functions are typically implemented using the trait RuntimeReplaceable.

SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null);
SELECT nullif('x', 'x'), nullif('x', 'y');
SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null);
SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null);

-- type coercion
SELECT ifnull(1, 2.1d), ifnull(null, 2.1d);
SELECT nullif(1, 2.1d), nullif(1, 1.0d);
SELECT nvl(1, 2.1d), nvl(null, 2.1d);
SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d);

-- explain for these functions; use range to avoid constant folding
explain extended
select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y')
from range(2);

-- SPARK-16730 cast alias functions for Hive compatibility
SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1);
SELECT float(1), double(1), decimal(1);
SELECT date("2014-04-04"), timestamp(date("2014-04-04"));
-- error handling: only one argument
SELECT string(1, 2);
5 changes: 4 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/array.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 10
-- Number of queries: 12


-- !query 0
@@ -124,6 +124,7 @@ struct<sort_array(boolean_array, true):array<boolean>,sort_array(tinyint_array,
-- !query 8 output
[true] [1,2] [1,2] [1,2] [1,2] [9223372036854775808,9223372036854775809] [1.0,2.0] [1.0,2.0] [2016-03-13,2016-03-14] [2016-11-12 20:54:00.0,2016-11-15 20:54:00.0]


-- !query 9
select sort_array(array('b', 'd'), '1')
-- !query 9 schema
@@ -132,6 +133,7 @@ struct<>
org.apache.spark.sql.AnalysisException
cannot resolve 'sort_array(array('b', 'd'), '1')' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7


-- !query 10
select sort_array(array('b', 'd'), cast(NULL as boolean))
-- !query 10 schema
@@ -140,6 +142,7 @@ struct<>
org.apache.spark.sql.AnalysisException
cannot resolve 'sort_array(array('b', 'd'), CAST(NULL AS BOOLEAN))' due to data type mismatch: Sort order in second argument requires a boolean literal.; line 1 pos 7


-- !query 11
select
size(boolean_array),
@@ -0,0 +1,124 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 13


-- !query 0
SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null)
-- !query 0 schema
struct<ifnull(NULL, 'x'):string,ifnull('y', 'x'):string,ifnull(NULL, NULL):null>
-- !query 0 output
x y NULL


-- !query 1
SELECT nullif('x', 'x'), nullif('x', 'y')
-- !query 1 schema
struct<nullif('x', 'x'):string,nullif('x', 'y'):string>
-- !query 1 output
NULL x


-- !query 2
SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)
-- !query 2 schema
struct<nvl(NULL, 'x'):string,nvl('y', 'x'):string,nvl(NULL, NULL):null>
-- !query 2 output
x y NULL


-- !query 3
SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null)
-- !query 3 schema
struct<nvl2(NULL, 'x', 'y'):string,nvl2('n', 'x', 'y'):string,nvl2(NULL, NULL, NULL):null>
-- !query 3 output
y x NULL


-- !query 4
SELECT ifnull(1, 2.1d), ifnull(null, 2.1d)
-- !query 4 schema
struct<ifnull(1, 2.1D):double,ifnull(NULL, 2.1D):double>
-- !query 4 output
1.0 2.1


-- !query 5
SELECT nullif(1, 2.1d), nullif(1, 1.0d)
-- !query 5 schema
struct<nullif(1, 2.1D):int,nullif(1, 1.0D):int>
-- !query 5 output
1 NULL


-- !query 6
SELECT nvl(1, 2.1d), nvl(null, 2.1d)
-- !query 6 schema
struct<nvl(1, 2.1D):double,nvl(NULL, 2.1D):double>
-- !query 6 output
1.0 2.1


-- !query 7
SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)
-- !query 7 schema
struct<nvl2(NULL, 1, 2.1D):double,nvl2('n', 1, 2.1D):double>
-- !query 7 output
2.1 1.0


-- !query 8
explain extended
select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y')
from range(2)
-- !query 8 schema
struct<plan:string>
-- !query 8 output
== Parsed Logical Plan ==
'Project [unresolvedalias('ifnull('id, x), None), unresolvedalias('nullif('id, x), None), unresolvedalias('nvl('id, x), None), unresolvedalias('nvl2('id, x, y), None)]
+- 'UnresolvedTableValuedFunction range, [2]

== Analyzed Logical Plan ==
ifnull(`id`, 'x'): string, nullif(`id`, 'x'): bigint, nvl(`id`, 'x'): string, nvl2(`id`, 'x', 'y'): string
Project [ifnull(id#xL, x) AS ifnull(`id`, 'x')#x, nullif(id#xL, x) AS nullif(`id`, 'x')#xL, nvl(id#xL, x) AS nvl(`id`, 'x')#x, nvl2(id#xL, x, y) AS nvl2(`id`, 'x', 'y')#x]
+- Range (0, 2, step=1, splits=None)

== Optimized Logical Plan ==
Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x]
+- Range (0, 2, step=1, splits=None)

== Physical Plan ==
*Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x]
+- *Range (0, 2, step=1, splits=None)


-- !query 9
SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1)
-- !query 9 schema
struct<CAST(1 AS BOOLEAN):boolean,CAST(1 AS TINYINT):tinyint,CAST(1 AS SMALLINT):smallint,CAST(1 AS INT):int,CAST(1 AS BIGINT):bigint>
-- !query 9 output
true 1 1 1 1


-- !query 10
SELECT float(1), double(1), decimal(1)
-- !query 10 schema
struct<CAST(1 AS FLOAT):float,CAST(1 AS DOUBLE):double,CAST(1 AS DECIMAL(10,0)):decimal(10,0)>
-- !query 10 output
1.0 1.0 1


-- !query 11
SELECT date("2014-04-04"), timestamp(date("2014-04-04"))
-- !query 11 schema
struct<CAST(2014-04-04 AS DATE):date,CAST(CAST(2014-04-04 AS DATE) AS TIMESTAMP):timestamp>
-- !query 11 output
2014-04-04 2014-04-04 00:00:00


-- !query 12
SELECT string(1, 2)
-- !query 12 schema
struct<>
-- !query 12 output
org.apache.spark.sql.AnalysisException
Function string accepts only one argument; line 1 pos 7