
[SPARK-39316][SQL] Merge PromotePrecision and CheckOverflow into decimal binary arithmetic #36698


Closed
ulysses-you wants to merge 19 commits from ulysses-you:decimal

Conversation

Contributor

@ulysses-you ulysses-you commented May 27, 2022

What changes were proposed in this pull request?

The main change:

  • Add a new method resultDecimalType in BinaryArithmetic
  • Add a new expression DecimalAddNoOverflowCheck for the internal decimal add used by e.g. Sum/Average; the differences from Add are:
    • DecimalAddNoOverflowCheck does not check overflow
    • DecimalAddNoOverflowCheck takes dataType as an input parameter
  • Merge the decimal precision code of DecimalPrecision into each arithmetic expression's data type, so every arithmetic expression reports an accurate decimal type, and the now-unused expression PromotePrecision and related code can be removed
  • Merge CheckOverflow into the arithmetic eval and code-gen code paths, so every arithmetic expression can handle the overflow case at runtime

Merge PromotePrecision into dataType, for example, Add:

override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
  val resultScale = max(s1, s2)
  if (allowPrecisionLoss) {
    DecimalType.adjustPrecisionScale(max(p1 - s1, p2 - s2) + resultScale + 1,
      resultScale)
  } else {
    DecimalType.bounded(max(p1 - s1, p2 - s2) + resultScale + 1, resultScale)
  }
}
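
A quick worked instance of this formula, with input types chosen purely for illustration:

```scala
// Add(decimal(10, 2), decimal(5, 4)), assuming no precision adjustment is triggered:
//   resultScale     = max(2, 4)                  = 4
//   resultPrecision = max(10 - 2, 5 - 4) + 4 + 1 = 13
// => result type: decimal(13, 4)
```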

Merge CheckOverflow, for example, Add eval:

dataType match {
  case decimalType: DecimalType =>
    val value = numeric.plus(input1, input2)
    checkOverflow(value.asInstanceOf[Decimal], decimalType)
  ...
}
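
For illustration only, here is a minimal REPL-style sketch of what such a checkOverflow-style guard does, written against plain java.math.BigDecimal rather than Spark's Decimal; the function name and signature are assumptions, not Spark's API:

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// Round the raw result to the target scale, then return null (nullOnOverflow = true)
// or throw (ANSI-style behavior) if the rounded value exceeds the target precision.
def checkOverflowSketch(
    value: JBigDecimal,
    precision: Int,
    scale: Int,
    nullOnOverflow: Boolean): JBigDecimal = {
  val rounded = value.setScale(scale, RoundingMode.HALF_UP)
  if (rounded.precision > precision) {
    if (nullOnOverflow) null
    else throw new ArithmeticException(
      s"${value.toPlainString} cannot be represented as DECIMAL($precision, $scale)")
  } else rounded
}
```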

Note that CheckOverflow is still useful after this PR, e.g. in RowEncoder. Further cleanup can be done in a separate PR.

Why are the changes needed?

  • Fix a bug in TypeCoercion; for example:

    SELECT CAST(1 AS DECIMAL(28, 2))
    UNION ALL
    SELECT CAST(1 AS DECIMAL(18, 2)) / CAST(1 AS DECIMAL(18, 2));

    The union result data type does not follow the formula (see the worked derivation after this list):

    Operation     Result Precision                   Result Scale
    e1 union e2   max(s1, s2) + max(p1-s1, p2-s2)    max(s1, s2)
    -- before
    -- query schema
    decimal(28,2)
    -- query output
    1.00
    1.00
    
    -- after
    -- query schema
    decimal(38,20)
    -- query output
    1.00000000000000000000
    1.00000000000000000000
  • Defer decimal precision promotion to runtime, so we no longer need redundant Casts
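
To make the corrected result type concrete, the following standalone sketch (plain Scala, not Spark code) reproduces the precision/scale arithmetic above, assuming the default spark.sql.decimalOperations.allowPrecisionLoss=true; the object and constant names are made up for illustration:

```scala
object UnionDecimalDerivation {
  val MaxPrecision = 38
  val MinimumAdjustedScale = 6

  // Mirrors the allowPrecisionLoss branch of DecimalType.adjustPrecisionScale.
  def adjust(precision: Int, scale: Int): (Int, Int) =
    if (precision <= MaxPrecision) (precision, scale)
    else {
      val intDigits = precision - scale
      val minScale = math.min(scale, MinimumAdjustedScale)
      (MaxPrecision, math.max(MaxPrecision - intDigits, minScale))
    }

  def main(args: Array[String]): Unit = {
    // decimal(18,2) / decimal(18,2):
    //   scale     = max(6, s1 + p2 + 1)   = 21
    //   precision = p1 - s1 + s2 + scale  = 39  -> adjusted to (38, 20)
    val divScale = math.max(MinimumAdjustedScale, 2 + 18 + 1)
    val (divP, divS) = adjust(18 - 2 + 2 + divScale, divScale)

    // union of decimal(28,2) and decimal(38,20):
    //   scale     = max(s1, s2)                    = 20
    //   precision = max(p1 - s1, p2 - s2) + scale  = 46 -> capped at 38
    val unionS = math.max(2, divS)
    val unionP = math.min(MaxPrecision, math.max(28 - 2, divP - divS) + unionS)

    println(s"division -> decimal($divP,$divS), union -> decimal($unionP,$unionS)")
    // prints: division -> decimal(38,20), union -> decimal(38,20)
  }
}
```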

Does this PR introduce any user-facing change?

yes, bug fix

How was this patch tested?

Pass existing tests and add some tests for the bug fix in decimalArithmeticOperations.sql

@github-actions github-actions bot added the SQL label May 27, 2022
@ulysses-you ulysses-you force-pushed the decimal branch 4 times, most recently from bf6c368 to 06265b1 Compare May 28, 2022 05:18
@ulysses-you
Contributor Author

cc @cloud-fan @viirya @HyukjinKwon

@cloud-fan
Contributor

cc @manuzhang this should fix the bug you hit

@manuzhang
Contributor

Besides the bug fix test in union.sql, we'd better test all the DecimalArithmetic types. I don't think they've been covered before.

@@ -313,7 +379,7 @@ object BinaryArithmetic {
case class Add(
left: Expression,
right: Expression,
failOnError: Boolean = SQLConf.get.ansiEnabled) extends BinaryArithmetic {
failOnError: Boolean = SQLConf.get.ansiEnabled) extends DecimalArithmetic {
Contributor

This looks a bit weird. I think supporting decimal is an additional ability, and we'd better use a mix-in trait, e.g. extends BinaryArithmetic with DecimalArithmetic

trait DecimalArithmetic { self: BinaryArithmetic =>

}

Contributor Author

yeah, better to use a trait

}

/** Name of the function for this expression on a [[Decimal]] type. */
def decimalMethod: String =
Contributor

I think every subclass should implement it?

Contributor Author

yes, it is pulled out from BinaryArithmetic. I will change the method to protected.

@@ -151,106 +152,110 @@ Functions [1]: [avg(revenue#21)]
Aggregate Attributes [1]: [avg(revenue#21)#27]
Results [2]: [ss_store_sk#13, avg(revenue#21)#27 AS ave#28]

(23) BroadcastExchange
(23) Filter [codegen id : 6]
Contributor

We need to investigate why we have an extra Filter now.

Contributor Author

sure, I'm looking at this plan change

Contributor Author

The reason is that both PromotePrecision and CheckOverflow do not inherit NullIntolerant, so the optimizer rule InferFiltersFromConstraints cannot infer IS NOT NULL from the attribute inside them.

This affects all decimal binary arithmetic. This PR removes PromotePrecision and CheckOverflow, so we now get the extra filter.

Contributor

I see, and it's actually beneficial.

left: Expression,
right: Expression,
override val dataType: DataType,
failOnError: Boolean = SQLConf.get.ansiEnabled) extends DecimalArithmetic {
Contributor

@cloud-fan cloud-fan May 30, 2022

where do we use this failOnError parameter?

Contributor Author

it's an overridden field from BinaryArithmetic

@ulysses-you
Contributor Author

Besides the bug fix test in union.sql, we'd better test all the DecimalArithmetic types. I don't think they've been covered before.

@manuzhang yeah, will add some more tests for all decimal binary arithmetic

@@ -57,10 +53,13 @@ import org.apache.spark.sql.types._
* - LONG gets turned into DECIMAL(20, 0)
* - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
* - Literals INT and LONG get turned into DECIMAL with the precision strictly needed by the value
*
* Note that, after SPARK-39316 all binary decimal arithmetic expressions report decimal type in
Contributor

instead of adding a note, can we just remove unrelated content (binary arithmetic) from the classdoc?

@@ -244,8 +311,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
case _: DecimalType =>
// Overflow is handled in the CheckOverflow operator
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
throw QueryExecutionErrors.unsupportedTypeError(dataType)
Contributor

we should throw IllegalStateException, as it's a bug if we hit this branch.

* or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an
* `ArithmeticException` is thrown.
*/
trait DecimalArithmetic extends BinaryArithmetic {
Contributor

@cloud-fan cloud-fan May 31, 2022

nit: SupportDecimalArithmetic

@@ -323,11 +390,27 @@ case class Add(

override def decimalMethod: String = "$plus"

// * Operation Result Precision Result Scale
// * ------------------------------------------------------------------------
// * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
Contributor

// The formula follows Hive which is based on the SQL standard and MS SQL:
// https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
// https://msdn.microsoft.com/en-us/library/ms190476.aspx
// Result Precision: max(s1, s2) + max(p1-s1, p2-s2) + 1
// Result Scale:     max(s1, s2)


/** Name of the function for this expression on a [[Decimal]] type. */
protected def decimalMethod: String =
throw QueryExecutionErrors.notOverrideExpectedMethodsError("DecimalArithmetic",
Contributor

which sub-class does not override it?

Contributor Author

Pmod

Contributor

I don't see a problem with putting override def decimalMethod: String = "remainder" in Pmod, though it's not used.

val value = ctx.freshName("value")
// scalastyle:off line.size.limit
s"""
|$javaType $value = $eval1.$decimalMethod($eval2);
Contributor

nit: ${ev.value} = $eval1.$decimalMethod($eval2);. We can assign a variable twice in Java.

Contributor

or just combine it

${ev.value} = $eval1.$decimalMethod($eval2).toPrecision...

Contributor Author

yeah, combined

private lazy val div: (Any, Any) => Any = dataType match {
case decimalType: DecimalType => (l, r) => {
val value = decimalType.fractional.asInstanceOf[Fractional[Any]].div(l, r)
Contributor

nit: indentation is wrong

@@ -232,3 +216,33 @@ case class CheckOverflowInSum(
override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum =
copy(child = newChild)
}

/**
* An add expression which is only used for internal add with decimal type.
Contributor

Suggested change
* An add expression which is only used for internal add with decimal type.
* An add expression which is only used internally by Sum/Avg

"decimalType", "dataType")

override def nullable: Boolean = dataType match {
case _: DecimalType => true
Contributor

should it depend on nullOnOverflow?

Contributor Author

yes, nullOnOverflow is more accurate

Contributor Author

after some thought, we cannot simply use nullOnOverflow here. If we want to use nullOnOverflow, the code-gen should be:

if (nullOnOverflow) {
  ${ev.isNull} = ${ev.value} == null;
}

leejaywei pushed a commit to Kyligence/spark that referenced this pull request Jul 14, 2022
…mal binary arithmetic (#481)

leejaywei pushed a commit to Kyligence/spark that referenced this pull request Jul 19, 2022
…mal binary arithmetic

zheniantoushipashi pushed a commit to Kyligence/spark that referenced this pull request Aug 8, 2022
…mal binary arithmetic (#481)

@gengliangwang
Member

@ulysses-you Is the following query an actual bug before the refactor? Or did the refactor just remove the redundant cast?

SELECT CAST(1 AS DECIMAL(28, 2))
UNION ALL
SELECT CAST(1 AS DECIMAL(18, 2)) / CAST(1 AS DECIMAL(18, 2));

@ulysses-you
Contributor Author

@gengliangwang it is a bug fix and also an improvement that saves an unnecessary cast. The query used to produce an unexpected precision and scale. Before: decimal(28,2); after: decimal(38,20).

cloud-fan pushed a commit that referenced this pull request Jan 31, 2023
### What changes were proposed in this pull request?

0 is a special case for decimal, whose data type can be decimal(0, 0); to be safe we should use decimal(1, 0) to represent 0.

### Why are the changes needed?

Fix a data correctness regression.

We no longer promote the decimal precision since the decimal binary operator refactor in #36698. However, this causes the intermediate decimal type of `IntegralDivide` to be decimal(0, 0). This is dangerous because Spark does not actually support decimal(0, 0), e.g.
```sql
-- work with in-memory catalog
create table t (c decimal(0, 0)) using parquet;
-- fail with parquet
-- java.lang.IllegalArgumentException: Invalid DECIMAL precision: 0
--	at org.apache.parquet.Preconditions.checkArgument(Preconditions.java:57)
insert into table t values(0);

-- fail with hive catalog
-- Caused by: java.lang.IllegalArgumentException: Decimal precision out of allowed range [1,38]
--	at org.apache.hadoop.hive.serde2.typeinfo.HiveDecimalUtils.validateParameter(HiveDecimalUtils.java:44)
create table t (c decimal(0, 0)) using parquet;
```
And decimal(0, 0) means the data is 0, so to be safe we use decimal(1, 0) to represent 0 for `IntegralDivide`.

### Does this PR introduce _any_ user-facing change?

yes, bug fix

### How was this patch tested?

add test

Closes #38760 from ulysses-you/SPARK-41219.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan pushed a commit that referenced this pull request Jan 31, 2023
cloud-fan pushed a commit that referenced this pull request Feb 23, 2023
…nge clause on large decimals

### What changes were proposed in this pull request?

Use `DecimalAddNoOverflowCheck` instead of `Add` to create the bound ordering for the window range frame

### Why are the changes needed?

Before 3.4, `Add` did not check overflow; instead, we always wrapped `Add` with a `CheckOverflow`. After #36698, `Add` checks overflow by itself. However, the bound ordering of the window range frame uses `Add` to calculate the boundary that determines which input rows lie within the frame boundaries of an output row, so the behavior changed with an extra overflow check.

Technically, we could allow an overflowing value if it is just an intermediate result, so this PR uses `DecimalAddNoOverflowCheck` to replace `Add` and restore the previous behavior.

### Does this PR introduce _any_ user-facing change?

yes, restores the previous (before 3.4) behavior

### How was this patch tested?

add test

Closes #40138 from ulysses-you/SPARK-41793.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan pushed a commit that referenced this pull request Feb 23, 2023
…nge clause on large decimals

@viirya
Member

viirya commented Mar 1, 2023

Looking at this motivated by something related. But the bug is not so clear in the PR description and the JIRA: I only see a claim that there is a bug, with no clear description of what the bug actually is and what the correct behavior is.

Is it possible to update the PR description and JIRA (even they were merged/closed)? It's still valuable for others for tracking purpose.

@ulysses-you
Contributor Author

@viirya sure, I have updated the PR description and JIRA. Hope it is clearer now.

@viirya
Member

viirya commented Mar 2, 2023

Thank you @ulysses-you

snmvaughan pushed a commit to snmvaughan/spark that referenced this pull request Jun 20, 2023
snmvaughan pushed a commit to snmvaughan/spark that referenced this pull request Jun 20, 2023
…nge clause on large decimals

RolatZhang pushed a commit to Kyligence/spark that referenced this pull request Aug 29, 2023
…mal binary arithmetic (#481)

RolatZhang pushed a commit to Kyligence/spark that referenced this pull request Aug 29, 2023
…mal binary arithmetic

wangyum pushed a commit that referenced this pull request Sep 6, 2023
…#dataType` when processing multi-column data

### What changes were proposed in this pull request?

Since `BinaryArithmetic#dataType` will recursively process the datatype of each node, the driver will be very slow when multiple columns are processed.

For example, the following code:
```scala
import spark.implicits._
import scala.util.Random
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.{StructType, StructField, IntegerType}

val N = 30
val M = 100

val columns = Seq.fill(N)(Random.alphanumeric.take(8).mkString)
val data = Seq.fill(M)(Seq.fill(N)(Random.nextInt(16) - 5))

val schema = StructType(columns.map(StructField(_, IntegerType)))
val rdd = spark.sparkContext.parallelize(data.map(Row.fromSeq(_)))
val df = spark.createDataFrame(rdd, schema)
val colExprs = columns.map(sum(_))

// gen a new column , and add the other 30 column
df.withColumn("new_col_sum", expr(columns.mkString(" + ")))
```

This code will take a few minutes for the driver to execute in the spark3.4 version, but only takes a few seconds to execute in the spark3.2 version. Related issue: [SPARK-39316](#36698)

### Why are the changes needed?

Optimize the processing speed of `BinaryArithmetic#dataType` when processing multi-column data

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

manual testing

### Was this patch authored or co-authored using generative AI tooling?

no

Closes #42804 from zzzzming95/SPARK-45071.

Authored-by: zzzzming95 <505306252@qq.com>
Signed-off-by: Yuming Wang <yumwang@ebay.com>
wangyum pushed a commit that referenced this pull request Sep 6, 2023
…#dataType` when processing multi-column data

wangyum pushed a commit that referenced this pull request Sep 6, 2023
…#dataType` when processing multi-column data

viirya pushed a commit to viirya/spark-1 that referenced this pull request Oct 19, 2023
…#dataType` when processing multi-column data

zml1206 pushed a commit to zml1206/spark that referenced this pull request May 7, 2025
…#dataType` when processing multi-column data
