Skip to content

[SPARK-16714][SPARK-16735][SPARK-16646] array, map, greatest, least's type coercion should handle decimal type #14439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,6 @@ object TypeCoercion {
})
}

/**
* Similar to [[findTightestCommonType]], if can not find the TightestCommonType, try to use
* [[findTightestCommonTypeToString]] to find the TightestCommonType.
*/
private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case None => None
case Some(d) =>
findTightestCommonTypeToString(d, c)
})
}

/**
* Find the tightest common type of a set of types by continuously applying
* `findTightestCommonTypeOfTwo` on these types.
Expand Down Expand Up @@ -157,6 +145,28 @@ object TypeCoercion {
})
}

/**
* Similar to [[findWiderCommonType]], but can't promote to string. This is also similar to
* [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds
* system limitation, this rule will truncate the decimal type before return it.
*/
private def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is weird that its name is findWiderTypeWithoutStringPromotion because findTightestCommonTypeOfTwo is used inside. Also, let's add more docs to this method.

types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match {
case (t1: DecimalType, t2: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(t1, t2))
case (t: IntegralType, d: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (d: DecimalType, t: IntegralType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
Some(DoubleType)
case _ => None
})
case None => None
})
}

private def haveSameType(exprs: Seq[Expression]): Boolean =
exprs.map(_.dataType).distinct.length == 1

Expand Down Expand Up @@ -440,7 +450,7 @@ object TypeCoercion {

case a @ CreateArray(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
findWiderCommonType(types) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case None => a
}
Expand All @@ -451,7 +461,7 @@ object TypeCoercion {
m.keys
} else {
val types = m.keys.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
findWiderCommonType(types) match {
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
case None => m.keys
}
Expand All @@ -461,7 +471,7 @@ object TypeCoercion {
m.values
} else {
val types = m.values.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
findWiderCommonType(types) match {
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
case None => m.values
}
Expand Down Expand Up @@ -494,16 +504,19 @@ object TypeCoercion {
case None => c
}

// When finding wider type for `Greatest` and `Least`, we should handle decimal types even if
// we need to truncate, but we should not promote one side to string if the other side is
// string.g
case g @ Greatest(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findTightestCommonType(types) match {
findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
case None => g
}

case l @ Least(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findTightestCommonType(types) match {
findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
case None => l
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
assertError(operator(Seq('intField, 'decimalField)), "should all have the same type")
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,24 @@ class TypeCoercionSuite extends PlanTest {
:: Cast(Literal(1), StringType)
:: Cast(Literal("a"), StringType)
:: Nil))

ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal(1)
:: Nil),
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 3))
:: Literal(1).cast(DecimalType(13, 3))
:: Nil))

ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal.create(null, DecimalType(22, 10))
:: Literal.create(null, DecimalType(38, 38))
:: Nil),
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(38, 38))
:: Literal.create(null, DecimalType(22, 10)).cast(DecimalType(38, 38))
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
:: Nil))
}

test("CreateMap casts") {
Expand All @@ -298,6 +316,17 @@ class TypeCoercionSuite extends PlanTest {
:: Cast(Literal.create(2.0, FloatType), FloatType)
:: Literal("b")
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal.create(null, DecimalType(5, 3))
:: Literal("a")
:: Literal.create(2.0, FloatType)
:: Literal("b")
:: Nil),
CreateMap(Literal.create(null, DecimalType(5, 3)).cast(DoubleType)
:: Literal("a")
:: Literal.create(2.0, FloatType).cast(DoubleType)
:: Literal("b")
:: Nil))
// type coercion for map values
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
Expand All @@ -310,6 +339,17 @@ class TypeCoercionSuite extends PlanTest {
:: Literal(2)
:: Cast(Literal(3.0), StringType)
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal.create(null, DecimalType(38, 0))
:: Literal(2)
:: Literal.create(null, DecimalType(38, 38))
:: Nil),
CreateMap(Literal(1)
:: Literal.create(null, DecimalType(38, 0)).cast(DecimalType(38, 38))
:: Literal(2)
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
:: Nil))
// type coercion for both map keys and values
ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
Expand Down Expand Up @@ -344,6 +384,33 @@ class TypeCoercionSuite extends PlanTest {
:: Cast(Literal(1), DecimalType(22, 0))
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1.0)
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
:: Nil),
operator(Literal(1.0).cast(DoubleType)
:: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
:: Literal(1).cast(DoubleType)
:: Nil))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems this test does not cover the logic of handling decimal types having different precisions and scales.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will push a new commit with two more tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal.create(null, DecimalType(15, 0))
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
:: Nil),
operator(Literal.create(null, DecimalType(15, 0)).cast(DecimalType(20, 5))
:: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5))
:: Literal(1).cast(DecimalType(20, 5))
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal.create(2L, LongType)
:: Literal(1)
:: Literal.create(null, DecimalType(10, 5))
:: Nil),
operator(Literal.create(2L, LongType).cast(DecimalType(25, 5))
:: Literal(1).cast(DecimalType(25, 5))
:: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(25, 5))
:: Nil))
}
}

Expand Down