Commit 9d3d25b

MaxGekk authored and cloud-fan committed
[SPARK-34677][SQL] Support the +/- operators over ANSI SQL intervals
### What changes were proposed in this pull request?
Extend the `Add`, `Subtract` and `UnaryMinus` expressions to support the `DayTimeIntervalType` and `YearMonthIntervalType` added by #31614.

Note: the expressions can throw an overflow exception independently of the SQL config `spark.sql.ansi.enabled`. In this way, the modified expressions always behave in ANSI mode for the intervals.

### Why are the changes needed?
To conform to the ANSI SQL standard, which defines `+`/`-` over intervals:
<img width="822" alt="Screenshot 2021-03-09 at 21 59 22" src="https://user-images.githubusercontent.com/1580697/110523128-bd50ea80-8122-11eb-9982-782da0088d27.png">

### Does this PR introduce _any_ user-facing change?
It should not, since the new types have not been released yet.

### How was this patch tested?
By running new tests in the test suites:
```
$ build/sbt "test:testOnly *ArithmeticExpressionSuite"
$ build/sbt "test:testOnly *ColumnExpressionSuite"
```

Closes #31789 from MaxGekk/add-subtruct-intervals.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 5c4d8f9 commit 9d3d25b
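
The gist of the change: year-month intervals are carried as an `Int` number of months and day-time intervals as a `Long` number of microseconds, and the extended operators delegate to the exact methods of `java.lang.Math`, so overflow raises `ArithmeticException` regardless of `spark.sql.ansi.enabled`. A minimal plain-Scala sketch of that behavior (outside Spark, just the JVM methods the new branches call):

```scala
import scala.util.Try

// Physical payloads of the new interval types (per the diff below):
// year-month interval -> Int months, day-time interval -> Long microseconds.
val months: Int = Int.MaxValue
val micros: Long = Long.MaxValue - 5L

println(Math.negateExact(-3))                      // 3
println(Try(Math.addExact(months, 10)))            // Failure(java.lang.ArithmeticException: integer overflow)
println(Try(Math.addExact(micros, 10L)))           // Failure(java.lang.ArithmeticException: long overflow)
println(Try(Math.subtractExact(Int.MinValue, 1)))  // Failure(java.lang.ArithmeticException: integer overflow)
```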

9 files changed: +106 -25 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 21 additions & 0 deletions
```diff
@@ -83,12 +83,19 @@ case class UnaryMinus(
       val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
       val method = if (failOnError) "negateExact" else "negate"
       defineCodeGen(ctx, ev, c => s"$iu.$method($c)")
+    case DayTimeIntervalType | YearMonthIntervalType =>
+      nullSafeCodeGen(ctx, ev, eval => {
+        val mathClass = classOf[Math].getName
+        s"${ev.value} = $mathClass.negateExact($eval);"
+      })
   }
 
   protected override def nullSafeEval(input: Any): Any = dataType match {
     case CalendarIntervalType if failOnError =>
       IntervalUtils.negateExact(input.asInstanceOf[CalendarInterval])
     case CalendarIntervalType => IntervalUtils.negate(input.asInstanceOf[CalendarInterval])
+    case DayTimeIntervalType => Math.negateExact(input.asInstanceOf[Long])
+    case YearMonthIntervalType => Math.negateExact(input.asInstanceOf[Int])
     case _ => numeric.negate(input)
   }
 
@@ -185,6 +192,12 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
     case CalendarIntervalType =>
       val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)")
+    case DayTimeIntervalType | YearMonthIntervalType =>
+      assert(exactMathMethod.isDefined,
+        s"The expression '$nodeName' must override the exactMathMethod() method " +
+        "if it is supposed to operate over interval types.")
+      val mathClass = classOf[Math].getName
+      defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathClass.${exactMathMethod.get}($eval1, $eval2)")
     // byte and short are casted into int when add, minus, times or divide
     case ByteType | ShortType =>
       nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
@@ -267,6 +280,10 @@ case class Add(
     case CalendarIntervalType =>
       IntervalUtils.add(
         input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
+    case DayTimeIntervalType =>
+      Math.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
+    case YearMonthIntervalType =>
+      Math.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
     case _ => numeric.plus(input1, input2)
   }
 
@@ -306,6 +323,10 @@ case class Subtract(
     case CalendarIntervalType =>
       IntervalUtils.subtract(
         input1.asInstanceOf[CalendarInterval], input2.asInstanceOf[CalendarInterval])
+    case DayTimeIntervalType =>
+      Math.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
+    case YearMonthIntervalType =>
+      Math.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])
     case _ => numeric.minus(input1, input2)
   }
 
```

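In the `BinaryArithmetic` branch above, the generated Java simply calls the `java.lang.Math` method named by `exactMathMethod`, which `Add`/`Subtract` are expected to supply (e.g. `Some("addExact")`; those overrides are not part of this diff). A toy rendering of the string the codegen template produces, with made-up variable names:

```scala
// Illustrative only: mimics the interpolation in the new codegen branch.
val mathClass = classOf[Math].getName                   // "java.lang.Math"
val exactMathMethod: Option[String] = Some("addExact")  // assumed override in Add
val (eval1, eval2) = ("value_1", "value_2")             // hypothetical generated variable names

println(s"$mathClass.${exactMathMethod.get}($eval1, $eval2)")
// java.lang.Math.addExact(value_1, value_2)
```
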
sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 5 additions & 1 deletion
```diff
@@ -82,7 +82,11 @@ private[sql] object TypeCollection {
    * Types that include numeric types and interval type. They are only used in unary_minus,
    * unary_positive, add and subtract operations.
    */
-  val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)
+  val NumericAndInterval = TypeCollection(
+    NumericType,
+    CalendarIntervalType,
+    DayTimeIntervalType,
+    YearMonthIntervalType)
 
   def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
 
```

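Per the scaladoc above, this collection is what unary minus, unary plus, add and subtract use to type-check their inputs, and its members are also what the analyzer prints when the check fails (see the golden-file updates below). A simplified illustration of how the widened message text comes about, assuming the usual `mkString`-style rendering:

```scala
// Simplified: the "(numeric or interval or daytimeinterval or yearmonthinterval)"
// fragment in the updated error messages is the collection's member names joined with " or ".
val members = Seq("numeric", "interval", "daytimeinterval", "yearmonthinterval")
println(members.mkString("(", " or ", ")"))
// (numeric or interval or daytimeinterval or yearmonthinterval)
```
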
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -78,9 +78,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(BitwiseXor(Symbol("intField"), Symbol("booleanField")))
 
     assertError(Add(Symbol("booleanField"), Symbol("booleanField")),
-      "requires (numeric or interval) type")
+      "requires (numeric or interval or daytimeinterval or yearmonthinterval) type")
     assertError(Subtract(Symbol("booleanField"), Symbol("booleanField")),
-      "requires (numeric or interval) type")
+      "requires (numeric or interval or daytimeinterval or yearmonthinterval) type")
     assertError(Multiply(Symbol("booleanField"), Symbol("booleanField")), "requires numeric type")
     assertError(Divide(Symbol("booleanField"), Symbol("booleanField")),
       "requires (double or decimal) type")
```

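The widened requirement also shows up verbatim in user-facing analysis errors, for example (matching the updated golden files below):

```scala
// Illustrative (spark-shell): unary minus over a non-numeric, non-interval
// operand now lists the two new interval types in the error message.
spark.sql("select -date '1999-01-01'")
// org.apache.spark.sql.AnalysisException: cannot resolve '(- DATE '1999-01-01')'
// due to data type mismatch: argument 1 requires (numeric or interval or
// daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of
// date type.; line 1 pos 7
```
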
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala

Lines changed: 40 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
+import java.time.{Duration, Period}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
@@ -576,4 +577,43 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       }
     }
   }
+
+  test("SPARK-34677: exact add and subtract of day-time and year-month intervals") {
+    Seq(true, false).foreach { failOnError =>
+      checkExceptionInExpression[ArithmeticException](
+        UnaryMinus(
+          Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType),
+          failOnError),
+        "overflow")
+      checkExceptionInExpression[ArithmeticException](
+        Subtract(
+          Literal.create(Period.ofMonths(Int.MinValue), YearMonthIntervalType),
+          Literal.create(Period.ofMonths(10), YearMonthIntervalType),
+          failOnError
+        ),
+        "overflow")
+      checkExceptionInExpression[ArithmeticException](
+        Add(
+          Literal.create(Period.ofMonths(Int.MaxValue), YearMonthIntervalType),
+          Literal.create(Period.ofMonths(10), YearMonthIntervalType),
+          failOnError
+        ),
+        "overflow")
+
+      checkExceptionInExpression[ArithmeticException](
+        Subtract(
+          Literal.create(Duration.ofDays(-106751991), DayTimeIntervalType),
+          Literal.create(Duration.ofDays(10), DayTimeIntervalType),
+          failOnError
+        ),
+        "overflow")
+      checkExceptionInExpression[ArithmeticException](
+        Add(
+          Literal.create(Duration.ofDays(106751991), DayTimeIntervalType),
+          Literal.create(Duration.ofDays(10), DayTimeIntervalType),
+          failOnError
+        ),
+        "overflow")
+    }
+  }
 }
```

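The day-time values in the test above sit right at the edge of the representable range: a day-time interval is a `Long` count of microseconds, and `Long.MaxValue` microseconds is about 106,751,991 days, so adding another 10 days must overflow. A quick check of that bound:

```scala
// Why Duration.ofDays(±106751991) is the near-overflow value in the test:
// Long.MaxValue expressed in whole days of microseconds.
val microsPerDay = 24L * 60 * 60 * 1000 * 1000  // 86,400,000,000 microseconds per day
println(Long.MaxValue / microsPerDay)           // 106751991
```
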
sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala

Lines changed: 4 additions & 1 deletion
```diff
@@ -52,7 +52,10 @@ object DataTypeTestUtils {
   /**
    * Instances of all [[NumericType]]s and [[CalendarIntervalType]]
    */
-  val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal + CalendarIntervalType
+  val numericAndInterval: Set[DataType] = numericTypeWithoutDecimal ++ Set(
+    CalendarIntervalType,
+    DayTimeIntervalType,
+    YearMonthIntervalType)
 
   /**
    * All the types that support ordering
```

sql/core/src/test/resources/sql-tests/results/ansi/literals.sql.out

Lines changed: 9 additions & 9 deletions
```diff
@@ -436,7 +436,7 @@ select +date '1999-01-01'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7
+cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7
 
 
 -- !query
@@ -445,7 +445,7 @@ select +timestamp '1999-01-01'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7
+cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7
 
 
 -- !query
@@ -462,7 +462,7 @@ select +map(1, 2)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'map(1, 2)' is of map<int,int> type.; line 1 pos 7
+cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'map(1, 2)' is of map<int,int> type.; line 1 pos 7
 
 
 -- !query
@@ -471,7 +471,7 @@ select +array(1,2)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'array(1, 2)' is of array<int> type.; line 1 pos 7
+cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'array(1, 2)' is of array<int> type.; line 1 pos 7
 
 
 -- !query
@@ -480,7 +480,7 @@ select +named_struct('a', 1, 'b', 'spark')
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct<a:int,b:string> type.; line 1 pos 7
+cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct<a:int,b:string> type.; line 1 pos 7
 
 
 -- !query
@@ -489,7 +489,7 @@ select +X'1'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'01'' is of binary type.; line 1 pos 7
+cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'01'' is of binary type.; line 1 pos 7
 
 
 -- !query
@@ -498,7 +498,7 @@ select -date '1999-01-01'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7
+cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7
 
 
 -- !query
@@ -507,7 +507,7 @@ select -timestamp '1999-01-01'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7
+cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7
 
 
 -- !query
@@ -516,4 +516,4 @@ select -x'2379ACFe'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7
+cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7
```

sql/core/src/test/resources/sql-tests/results/literals.sql.out

Lines changed: 9 additions & 9 deletions
```diff
@@ -436,7 +436,7 @@ select +date '1999-01-01'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7
+cannot resolve '(+ DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7
 
 
 -- !query
@@ -445,7 +445,7 @@ select +timestamp '1999-01-01'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7
+cannot resolve '(+ TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7
 
 
 -- !query
@@ -462,7 +462,7 @@ select +map(1, 2)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'map(1, 2)' is of map<int,int> type.; line 1 pos 7
+cannot resolve '(+ map(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'map(1, 2)' is of map<int,int> type.; line 1 pos 7
 
 
 -- !query
@@ -471,7 +471,7 @@ select +array(1,2)
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'array(1, 2)' is of array<int> type.; line 1 pos 7
+cannot resolve '(+ array(1, 2))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'array(1, 2)' is of array<int> type.; line 1 pos 7
 
 
 -- !query
@@ -480,7 +480,7 @@ select +named_struct('a', 1, 'b', 'spark')
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct<a:int,b:string> type.; line 1 pos 7
+cannot resolve '(+ named_struct('a', 1, 'b', 'spark'))' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'named_struct('a', 1, 'b', 'spark')' is of struct<a:int,b:string> type.; line 1 pos 7
 
 
 -- !query
@@ -489,7 +489,7 @@ select +X'1'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'01'' is of binary type.; line 1 pos 7
+cannot resolve '(+ X'01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'01'' is of binary type.; line 1 pos 7
 
 
 -- !query
@@ -498,7 +498,7 @@ select -date '1999-01-01'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7
+cannot resolve '(- DATE '1999-01-01')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'DATE '1999-01-01'' is of date type.; line 1 pos 7
 
 
 -- !query
@@ -507,7 +507,7 @@ select -timestamp '1999-01-01'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7
+cannot resolve '(- TIMESTAMP '1999-01-01 00:00:00')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'TIMESTAMP '1999-01-01 00:00:00'' is of timestamp type.; line 1 pos 7
 
 
 -- !query
@@ -516,4 +516,4 @@ select -x'2379ACFe'
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7
+cannot resolve '(- X'2379ACFE')' due to data type mismatch: argument 1 requires (numeric or interval or daytimeinterval or yearmonthinterval) type, however, 'X'2379ACFE'' is of binary type.; line 1 pos 7
```

sql/core/src/test/resources/sql-tests/results/typeCoercion/native/windowFrameCoercion.sql.out

Lines changed: 3 additions & 3 deletions
```diff
@@ -168,7 +168,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as string) DESC RANGE BETWE
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'string' does not match the expected data type '(numeric or interval)'.; line 1 pos 21
+cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS STRING) FOLLOWING' due to data type mismatch: The data type of the upper bound 'string' does not match the expected data type '(numeric or interval or daytimeinterval or yearmonthinterval)'.; line 1 pos 21
 
 
 -- !query
@@ -177,7 +177,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast('1' as binary) DESC RANGE BET
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BINARY) FOLLOWING' due to data type mismatch: The data type of the upper bound 'binary' does not match the expected data type '(numeric or interval)'.; line 1 pos 21
+cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BINARY) FOLLOWING' due to data type mismatch: The data type of the upper bound 'binary' does not match the expected data type '(numeric or interval or daytimeinterval or yearmonthinterval)'.; line 1 pos 21
 
 
 -- !query
@@ -186,7 +186,7 @@ SELECT COUNT(*) OVER (PARTITION BY 1 ORDER BY cast(1 as boolean) DESC RANGE BETW
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'boolean' does not match the expected data type '(numeric or interval)'.; line 1 pos 21
+cannot resolve 'RANGE BETWEEN CURRENT ROW AND CAST(1 AS BOOLEAN) FOLLOWING' due to data type mismatch: The data type of the upper bound 'boolean' does not match the expected data type '(numeric or interval or daytimeinterval or yearmonthinterval)'.; line 1 pos 21
 
 
 -- !query
```

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import java.sql.{Date, Timestamp}
+import java.time.{Duration, Period}
 import java.util.Locale
 
 import org.apache.hadoop.io.{LongWritable, Text}
@@ -2375,4 +2376,16 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
     assert(e2.getCause.isInstanceOf[RuntimeException])
     assert(e2.getCause.getMessage == "hello")
   }
+
+  test("SPARK-34677: negate/add/subtract year-month and day-time intervals") {
+    import testImplicits._
+    val df = Seq((Period.ofMonths(10), Duration.ofDays(10), Period.ofMonths(1), Duration.ofDays(1)))
+      .toDF("year-month-A", "day-time-A", "year-month-B", "day-time-B")
+    val negatedDF = df.select(-$"year-month-A", -$"day-time-A")
+    checkAnswer(negatedDF, Row(Period.ofMonths(-10), Duration.ofDays(-10)))
+    val addDF = df.select($"year-month-A" + $"year-month-B", $"day-time-A" + $"day-time-B")
+    checkAnswer(addDF, Row(Period.ofMonths(11), Duration.ofDays(11)))
+    val subDF = df.select($"year-month-A" - $"year-month-B", $"day-time-A" - $"day-time-B")
+    checkAnswer(subDF, Row(Period.ofMonths(9), Duration.ofDays(9)))
+  }
 }
```

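For end users (once the types ship), the test above corresponds to spark-shell usage along these lines — a sketch assuming the `java.time.Period`/`Duration` support that the test itself relies on:

```scala
// spark-shell sketch (spark.implicits._ is already in scope there).
import java.time.{Duration, Period}

val df = Seq((Period.ofMonths(10), Duration.ofDays(10))).toDF("ym", "dt")

df.select(-$"ym", -$"dt").collect()
// e.g. Array([P-10M,PT-240H])  -> Period.ofMonths(-10), Duration.ofDays(-10)

df.select($"ym" + $"ym", $"dt" + $"dt").collect()
// e.g. Array([P20M,PT480H])    -> Period.ofMonths(20), Duration.ofDays(20)

// Overflow fails regardless of spark.sql.ansi.enabled, e.g. adding 10 months to
// Period.ofMonths(Int.MaxValue) ends in java.lang.ArithmeticException: integer overflow.
```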