@@ -366,91 +366,101 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest {
366366 )
367367 }
368368
// Registers the event-time-timeout test twice: once handing the timeout to
// state.setTimeoutTimestamp as epoch milliseconds (an int) and once as a
// datetime.datetime value. `foreach` is used rather than `map` because test(...)
// is called purely for its registration side effect.
Seq(true, false).foreach { ifUseDateTimeType =>
  test("applyInPandasWithState - streaming w/ event time timeout + watermark " +
    s"ifUseDateTimeType=$ifUseDateTimeType") {
    assume(shouldTestPandasUDFs)

    // timestamp_seconds assumes the base timezone is UTC. However, the provided function
    // localizes it. Therefore, this test assumes the timezone is in UTC
    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
      // Python expression spliced into the script below as the argument of
      // state.setTimeoutTimestamp(...). It is either a datetime built from the timeout
      // in seconds, or the timeout converted to milliseconds.
      val timeoutTimestampExpr = if (ifUseDateTimeType) {
        "datetime.datetime.fromtimestamp(timeout_timestamp_sec)"
      } else {
        "timeout_timestamp_sec * 1000"
      }

      // The Python function keeps the maximum event time (in seconds) seen per key as
      // state, arms an event-time timeout 5 seconds past that maximum, and emits -1 for
      // the key once the state has timed out.
      // NOTE(review): the lambda adds dt.microsecond to a value in seconds; every input
      // in this test is built from whole seconds, so the addend is presumably always 0.
      val pythonScript =
        s"""
          |import calendar
          |import os
          |import datetime
          |import pandas as pd
          |from pyspark.sql.types import StructType, StringType, StructField, IntegerType
          |
          |tpe = StructType([
          |    StructField("key", StringType()),
          |    StructField("maxEventTimeSec", IntegerType())])
          |
          |def func(key, pdf_iter, state):
          |    assert state.getCurrentProcessingTimeMs() >= 0
          |    assert state.getCurrentWatermarkMs() >= -1
          |
          |    timeout_delay_sec = 5
          |    if state.hasTimedOut:
          |        state.remove()
          |        yield pd.DataFrame({'key': [key[0]], 'maxEventTimeSec': [-1]})
          |    else:
          |        m = state.getOption
          |        if m is None:
          |            max_event_time_sec = 0
          |        else:
          |            max_event_time_sec = m[0]
          |
          |        for pdf in pdf_iter:
          |            pser = pdf.eventTime.apply(
          |                lambda dt: (int(calendar.timegm(dt.utctimetuple()) + dt.microsecond)))
          |            max_event_time_sec = int(max(pser.max(), max_event_time_sec))
          |
          |        state.update((max_event_time_sec,))
          |        timeout_timestamp_sec = max_event_time_sec + timeout_delay_sec
          |        state.setTimeoutTimestamp($timeoutTimestampExpr)
          |        yield pd.DataFrame({'key': [key[0]],
          |                            'maxEventTimeSec': [max_event_time_sec]})
          |""".stripMargin
      val pythonFunc = TestGroupedMapPandasUDFWithState(
        name = "pandas_grouped_map_with_state", pythonScript = pythonScript)

      // Input rows are (key, event time in seconds); the Int is turned into a timestamp
      // column so that a watermark can be defined on it.
      val inputData = MemoryStream[(String, Int)]
      val inputDataDF =
        inputData.toDF().select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
      val outputStructType = StructType(
        Seq(
          StructField("key", StringType),
          StructField("maxEventTimeSec", IntegerType)))
      val stateStructType = StructType(Seq(StructField("maxEventTimeSec", LongType)))
      val result =
        inputDataDF
          .withWatermark("eventTime", "10 seconds")
          .groupBy("key")
          .applyInPandasWithState(
            pythonFunc(inputDataDF("key"), inputDataDF("eventTime")).expr.asInstanceOf[PythonUDF],
            outputStructType,
            stateStructType,
            "Update",
            "EventTimeTimeout")

      testStream(result, Update)(
        StartStream(),

        AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
        // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5.
        CheckNewAnswer(("a", 15)), // Output = max event time of a

        AddData(inputData, ("a", 4)), // Add data older than watermark for "a"
        CheckNewAnswer(), // No output as data should get filtered by watermark

        AddData(inputData, ("a", 10)), // Add data newer than watermark for "a"
        CheckNewAnswer(("a", 15)), // Max event time is still the same
        // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15.
        // Watermark is still 5 as max event time for all data is still 15.

        AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a"
        // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is
        // 20.
        CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1
      )
    }
  }
}
456466
0 commit comments