Commit 405478b

siying authored and HeartSaVioR committed
[SPARK-48934][SS][3.4] Python datetime types converted incorrectly for setting timeout in applyInPandasWithState
### What changes were proposed in this pull request?

Fix the way applyInPandasWithState's setTimeoutTimestamp() handles an argument of type datetime.

### Why are the changes needed?

In applyInPandasWithState(), when state.setTimeoutTimestamp() is passed a datetime.datetime value, it doesn't function as expected. Fix it. Also fix another bug in reporting VALUE_NOT_POSITIVE, which triggers when the converted value is 0.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added unit test coverage for this scenario.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47413 from HeartSaVioR/SPARK-48934-3.4.

Authored-by: Siying Dong <siying.dong@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
1 parent 20da975 commit 405478b
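For context on the fix: PySpark's `DateType().toInternal()` returns *days* since the Unix epoch, while `TimestampType().toInternal()` returns *microseconds* since the epoch, and `GroupState.setTimeoutTimestamp()` expects epoch *milliseconds*. A minimal sketch contrasting the two conversions (the timestamp value is illustrative, not from the commit):

```python
import datetime

from pyspark.sql.types import DateType, TimestampType

dt = datetime.datetime(2024, 7, 22, 12, 0, 0, tzinfo=datetime.timezone.utc)

# Old, buggy conversion: DateType.toInternal returns *days* since the
# Unix epoch (19926 here) -- far too small to be a timestamp in
# milliseconds, so the timeout did not land where the user intended.
days = DateType().toInternal(dt)

# Fixed conversion: TimestampType.toInternal returns *microseconds*
# since the epoch; dividing by 1000 yields the epoch milliseconds that
# setTimeoutTimestamp() expects.
millis = TimestampType().toInternal(dt) / 1000

print(days, millis)  # 19926 vs 1721649600000.0
```

This presumably also explains the VALUE_NOT_POSITIVE report mentioned above: a datetime on 1970-01-01 converts to 0 days under the old code and trips the positivity check.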

File tree

2 files changed: +90 -80 lines changed

python/pyspark/sql/streaming/state.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -18,7 +18,7 @@
 import json
 from typing import Tuple, Optional
 
-from pyspark.sql.types import DateType, Row, StructType
+from pyspark.sql.types import Row, StructType, TimestampType
 from pyspark.sql.utils import has_numpy
 
 __all__ = ["GroupState", "GroupStateTimeout"]
@@ -195,7 +195,7 @@ def setTimeoutTimestamp(self, timestampMs: int) -> None:
             )
 
         if isinstance(timestampMs, datetime.datetime):
-            timestampMs = DateType().toInternal(timestampMs)
+            timestampMs = TimestampType().toInternal(timestampMs) / 1000
 
         if timestampMs <= 0:
             raise ValueError("Timeout timestamp must be positive")
```
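
For illustration, a user-level sketch of the code path this change repairs; the function body, schema, and timestamp below are hypothetical, not taken from the commit. With the fix, passing a `datetime.datetime` arms the same timeout as passing the equivalent epoch milliseconds (assuming a UTC session timezone):

```python
import datetime

import pandas as pd

def count_with_timeout(key, pdf_iter, state):
    # Keep a running count for this key in the group state.
    (count,) = state.get if state.exists else (0,)
    for pdf in pdf_iter:
        count += len(pdf)
    state.update((count,))
    # Fixed path: the datetime is converted via TimestampType().toInternal()
    # into microseconds, then divided by 1000 into the expected milliseconds.
    state.setTimeoutTimestamp(datetime.datetime(2024, 7, 22, 12, 0, 0))
    yield pd.DataFrame({"key": [key[0]], "count": [count]})
```

Such a function would be wired up via `groupBy(...).applyInPandasWithState(...)` with `EventTimeTimeout`, which is exactly what the updated test below exercises.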

sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala

Lines changed: 88 additions & 78 deletions
```diff
@@ -366,91 +366,101 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest {
     )
   }
 
-  test("applyInPandasWithState - streaming w/ event time timeout + watermark") {
-    assume(shouldTestPandasUDFs)
+  Seq(true, false).map { ifUseDateTimeType =>
+    test("applyInPandasWithState - streaming w/ event time timeout + watermark " +
+      s"ifUseDateTimeType=$ifUseDateTimeType") {
+      assume(shouldTestPandasUDFs)
 
-    // timestamp_seconds assumes the base timezone is UTC. However, the provided function
-    // localizes it. Therefore, this test assumes the timezone is in UTC
-    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
-      val pythonScript =
-        """
-          |import calendar
-          |import os
-          |import datetime
-          |import pandas as pd
-          |from pyspark.sql.types import StructType, StringType, StructField, IntegerType
-          |
-          |tpe = StructType([
-          |    StructField("key", StringType()),
-          |    StructField("maxEventTimeSec", IntegerType())])
-          |
-          |def func(key, pdf_iter, state):
-          |    assert state.getCurrentProcessingTimeMs() >= 0
-          |    assert state.getCurrentWatermarkMs() >= -1
-          |
-          |    timeout_delay_sec = 5
-          |    if state.hasTimedOut:
-          |        state.remove()
-          |        yield pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]})
-          |    else:
-          |        m = state.getOption
-          |        if m is None:
-          |            max_event_time_sec = 0
-          |        else:
-          |            max_event_time_sec = m[0]
-          |
-          |        for pdf in pdf_iter:
-          |            pser = pdf.eventTime.apply(
-          |                lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond)))
-          |            max_event_time_sec = int(max(pser.max(), max_event_time_sec))
-          |
-          |        state.update((max_event_time_sec,))
-          |        timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec
-          |        state.setTimeoutTimestamp(timeout_timestamp_sec * 1000)
-          |        yield pd.DataFrame({'key': [key[0]],
-          |                            'maxEventTimeSec': [max_event_time_sec]})
-          |""".stripMargin
-      val pythonFunc = TestGroupedMapPandasUDFWithState(
-        name = "pandas_grouped_map_with_state", pythonScript = pythonScript)
+      // timestamp_seconds assumes the base timezone is UTC. However, the provided function
+      // localizes it. Therefore, this test assumes the timezone is in UTC
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+        val timeoutMs = if (ifUseDateTimeType) {
+          "datetime.datetime.fromtimestamp(timeout_timestamp_sec)"
+        } else {
+          "timeout_timestamp_sec * 1000"
+        }
 
-      val inputData = MemoryStream[(String, Int)]
-      val inputDataDF =
-        inputData.toDF.select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
-      val outputStructType = StructType(
-        Seq(
-          StructField("key", StringType),
-          StructField("maxEventTimeSec", IntegerType)))
-      val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType)))
-      val result =
-        inputDataDF
-          .withWatermark("eventTime", "10 seconds")
-          .groupBy("key")
-          .applyInPandasWithState(
-            pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF],
-            outputStructType,
-            stateStructType,
-            "Update",
-            "EventTimeTimeout")
+        val pythonScript =
+          s"""
+             |import calendar
+             |import os
+             |import datetime
+             |import pandas as pd
+             |from pyspark.sql.types import StructType, StringType, StructField, IntegerType
+             |
+             |tpe = StructType([
+             |    StructField("key", StringType()),
+             |    StructField("maxEventTimeSec", IntegerType())])
+             |
+             |def func(key, pdf_iter, state):
+             |    assert state.getCurrentProcessingTimeMs() >= 0
+             |    assert state.getCurrentWatermarkMs() >= -1
+             |
+             |    timeout_delay_sec = 5
+             |    if state.hasTimedOut:
+             |        state.remove()
+             |        yield pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]})
+             |    else:
+             |        m = state.getOption
+             |        if m is None:
+             |            max_event_time_sec = 0
+             |        else:
+             |            max_event_time_sec = m[0]
+             |
+             |        for pdf in pdf_iter:
+             |            pser = pdf.eventTime.apply(
+             |                lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond)))
+             |            max_event_time_sec = int(max(pser.max(), max_event_time_sec))
+             |
+             |        state.update((max_event_time_sec,))
+             |        timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec
+             |        state.setTimeoutTimestamp($timeoutMs)
+             |        yield pd.DataFrame({'key': [key[0]],
+             |                            'maxEventTimeSec': [max_event_time_sec]})
+             |""".stripMargin.format("")
+        val pythonFunc = TestGroupedMapPandasUDFWithState(
+          name = "pandas_grouped_map_with_state", pythonScript = pythonScript)
 
-      testStream(result, Update)(
-        StartStream(),
+        val inputData = MemoryStream[(String, Int)]
+        val inputDataDF =
+          inputData.toDF().select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
+        val outputStructType = StructType(
+          Seq(
+            StructField("key", StringType),
+            StructField("maxEventTimeSec", IntegerType)))
+        val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType)))
+        val result =
+          inputDataDF
+            .withWatermark("eventTime", "10 seconds")
+            .groupBy("key")
+            .applyInPandasWithState(
+              pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF],
+              outputStructType,
+              stateStructType,
+              "Update",
+              "EventTimeTimeout")
 
-        AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
-        // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5.
-        CheckNewAnswer(("a", 15)), // Output = max event time of a
+        testStream(result, Update)(
+          StartStream(),
 
-        AddData(inputData, ("a", 4)), // Add data older than watermark for "a"
-        CheckNewAnswer(), // No output as data should get filtered by watermark
+          AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
+          // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5.
+          CheckNewAnswer(("a", 15)), // Output = max event time of a
 
-        AddData(inputData, ("a", 10)), // Add data newer than watermark for "a"
-        CheckNewAnswer(("a", 15)), // Max event time is still the same
-        // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15.
-        // Watermark is still 5 as max event time for all data is still 15.
+          AddData(inputData, ("a", 4)), // Add data older than watermark for "a"
+          CheckNewAnswer(), // No output as data should get filtered by watermark
 
-        AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a"
-        // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20.
-        CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1
-      )
+          AddData(inputData, ("a", 10)), // Add data newer than watermark for "a"
+          CheckNewAnswer(("a", 15)), // Max event time is still the same
+          // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15.
+          // Watermark is still 5 as max event time for all data is still 15.
+
+          AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a"
+          // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is
+          // 20.
+          CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1
+        )
+      }
     }
   }
 
```
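As a sanity check on the parameterization above: the two branches hand `setTimeoutTimestamp()` the same instant in different forms, so both generated tests should observe identical timeout behavior. A small sketch (not part of the commit) of why the two values agree under the test's UTC session timezone, mirroring the fixed conversion:

```python
import datetime

from pyspark.sql.types import TimestampType

# 20 = max event time (15) + timeout delay (5), as in the test comments.
timeout_timestamp_sec = 20

# Branch ifUseDateTimeType=false: epoch milliseconds passed directly.
as_millis = timeout_timestamp_sec * 1000

# Branch ifUseDateTimeType=true: a datetime, which the fixed
# GroupState.setTimeoutTimestamp converts via TimestampType().toInternal() / 1000.
as_datetime = datetime.datetime.fromtimestamp(timeout_timestamp_sec)
converted = TimestampType().toInternal(as_datetime) / 1000

assert converted == as_millis, (converted, as_millis)
```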