[SPARK-49601][SS][PYTHON] Support Initial State Handling for TransformWithStateInPandas #48005

Closed
wants to merge 12 commits
@@ -62,6 +62,7 @@ private[spark] object PythonEvalType {
val SQL_GROUPED_MAP_ARROW_UDF = 209
val SQL_COGROUPED_MAP_ARROW_UDF = 210
val SQL_TRANSFORM_WITH_STATE_PANDAS_UDF = 211
val SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF = 212

val SQL_TABLE_UDF = 300
val SQL_ARROW_TABLE_UDF = 301
@@ -84,6 +85,8 @@ private[spark] object PythonEvalType {
case SQL_TABLE_UDF => "SQL_TABLE_UDF"
case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF"
case SQL_TRANSFORM_WITH_STATE_PANDAS_UDF => "SQL_TRANSFORM_WITH_STATE_PANDAS_UDF"
case SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF =>
"SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF"
}
}

@@ -26,4 +26,5 @@ Stateful Processor

StatefulProcessor.init
StatefulProcessor.handleInputRows
StatefulProcessor.close
StatefulProcessor.close
StatefulProcessor.handleInitialState
1 change: 1 addition & 0 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -56,6 +56,7 @@ PandasGroupedMapUDFWithStateType = Literal[208]
ArrowGroupedMapUDFType = Literal[209]
ArrowCogroupedMapUDFType = Literal[210]
PandasGroupedMapUDFTransformWithStateType = Literal[211]
PandasGroupedMapUDFTransformWithStateInitStateType = Literal[212]

class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
2 changes: 2 additions & 0 deletions python/pyspark/sql/pandas/functions.py
@@ -414,6 +414,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
None,
@@ -455,6 +456,7 @@ def _validate_pandas_udf(f, evalType) -> int:
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_ARROW_BATCHED_UDF,
108 changes: 92 additions & 16 deletions python/pyspark/sql/pandas/group_ops.py
@@ -16,7 +16,7 @@
#
import itertools
import sys
from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast
from typing import Any, Iterator, List, Optional, Union, TYPE_CHECKING, cast
import warnings

from pyspark.errors import PySparkTypeError
@@ -373,6 +373,7 @@ def transformWithStateInPandas(
outputStructType: Union[StructType, str],
outputMode: str,
timeMode: str,
initialState: Optional["GroupedData"] = None,
) -> DataFrame:
"""
Invokes methods defined in the stateful processor used in arbitrary state API v2. It
@@ -409,6 +410,9 @@
The output mode of the stateful processor.
timeMode : str
The time mode semantics of the stateful processor for timers and TTL.
initialState : :class:`pyspark.sql.GroupedData`
Optional. A grouped DataFrame containing the initial states, used to initialize
the state variables in the first batch.

Examples
--------
@@ -493,22 +497,17 @@
from pyspark.sql.functions import pandas_udf

assert isinstance(self, GroupedData)
if initialState is not None:
assert isinstance(initialState, GroupedData)
if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

def transformWithStateUDF(
def handle_data_with_timers(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:
handle = StatefulProcessorHandle(statefulProcessorApiClient)

if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
statefulProcessor.init(handle)
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.INITIALIZED
)

statefulProcessorApiClient.set_implicit_key(key)

if timeMode != "none":
batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
@@ -551,25 +550,102 @@ def transformWithStateUDF(
# TODO(SPARK-49603) set the handle state in the lazily initialized iterator

result = itertools.chain(*result_iter_list)
return result

def transformWithStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:
handle = StatefulProcessorHandle(statefulProcessorApiClient)

if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
statefulProcessor.init(handle)
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.INITIALIZED
)

result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
return result

def transformWithStateWithInitStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
initialStates: Optional[Iterator["PandasDataFrameLike"]] = None,
) -> Iterator["PandasDataFrameLike"]:
Contributor:
Can we add some comments on the possible input combinations that we need to handle in this UDF, to make it easier for people to understand? IIUC there should be 3 cases:

  • Both inputRows and initialStates contain data. This would only happen in the first batch, when the associated grouping key has both input data and an initial state.
  • Only inputRows contains data. This could happen when either the grouping key doesn't have any initial state to process or it's not the first batch.
  • Only initialStates contains data. This could happen when the grouping key doesn't have any associated input data but does have an initial state to process.

Contributor Author:
Added the above in the comment.

"""
UDF for TWS operator with non-empty initial states. Possible input combinations
of inputRows and initialStates iterator:
- Both `inputRows` and `initialStates` are non-empty. Both input rows and initial
states contains the grouping key and data.
- `InitialStates` is non-empty, while `inputRows` is empty. Only initial states
contains the grouping key and data, and it is first batch.
- `initialStates` is empty, while `inputRows` is non-empty. Only inputRows contains the
grouping key and data, and it is first batch.
- `initialStates` is None, while `inputRows` is not empty. This is not first batch.
`initialStates` is initialized to the positional value as None.
"""
handle = StatefulProcessorHandle(statefulProcessorApiClient)

if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
Contributor:
There's something not very clear to me here, could you help me understand?

We only call handleInitialState when the handle state is CREATED, but after we process the initial state of the first grouping key, we update the state to INITIALIZED. Wouldn't that skip the initial state for the other grouping keys?

Contributor:
If my understanding is correct, we should move handleInitialState outside the handle-state check and do it after the init call.

Contributor Author:
You are correct. I moved the code block out and ran a local test with the partition number set to "1" to confirm the implementation is correct.

statefulProcessor.init(handle)
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.INITIALIZED
)

# only process initial state if first batch and initial state is not None
if initialStates is not None:
for cur_initial_state in initialStates:
statefulProcessorApiClient.set_implicit_key(key)
# TODO(SPARK-50194) integration with new timer API with initial state
statefulProcessor.handleInitialState(key, cur_initial_state)

# if we don't have input rows for the given key but only have initial state
# for the grouping key, the inputRows iterator could be empty
input_rows_empty = False
try:
first = next(inputRows)
except StopIteration:
input_rows_empty = True
else:
inputRows = itertools.chain([first], inputRows)

if not input_rows_empty:
Contributor:
Wait, isn't there a case where the inputRows iterator is empty but a timer is expected to expire?

My understanding is that if you pass transformWithStateWithInitStateUDF as the udf, it will be used for all batches, right? Then how does this handle the case of a grouping key in batch N that has no data but has a timer to expire?

Contributor:
If you don't have a test covering this scenario, please add it as well.

Contributor Author:
You are also right about this; that is the scenario covered in this PR: #45780. I was planning to finish the integration together with the new timer API that Anish merged last week. I also left a TODO a few lines above. If that's OK with you, I will finish this portion in SPARK-50194.

Contributor:
It's OK to defer the fix, but let's add the TODO comment to the non-initial-state path as well. It also has to be changed.

Contributor Author:
The timer integration for the non-initial-state path is already done here: https://github.com/apache/spark/pull/48005/files#diff-5862151bb5e9fe7a6b2d1978301c235504dcc6c1bbbd1f9745a204a3ba93146eR568.
We are keeping it the way it was, and this is also how a non-first batch with initial state is handled. We are only missing the corner case of a timer registered in the initial state, which is covered by the TODO.

result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
else:
result = iter([])

return result

if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

udf = pandas_udf(
transformWithStateUDF, # type: ignore
returnType=outputStructType,
functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
)
df = self._df

if initialState is None:
initial_state_java_obj = None
udf = pandas_udf(
transformWithStateUDF, # type: ignore
returnType=outputStructType,
functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
)
else:
initial_state_java_obj = initialState._jgd
udf = pandas_udf(
transformWithStateWithInitStateUDF, # type: ignore
returnType=outputStructType,
functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
)

udf_column = udf(*[df[col] for col in df.columns])

jdf = self._jgd.transformWithStateInPandas(
udf_column._jc,
self.session._jsparkSession.parseDataType(outputStructType.json()),
outputMode,
timeMode,
initial_state_java_obj,
)
return DataFrame(jdf, self.session)
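
A minimal usage sketch (not part of this diff) of the new initialState argument; input_df and MyProcessor are hypothetical placeholders for a streaming DataFrame with an "id" column and a StatefulProcessor subclass that also overrides handleInitialState:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Assumed initial state: one row of prior state per grouping key, grouped by that key.
init_df = spark.createDataFrame([("a", 5), ("b", 2)], ["id", "count"]).groupBy("id")

result = input_df.groupBy("id").transformWithStateInPandas(
    statefulProcessor=MyProcessor(),
    outputStructType="id string, count long",
    outputMode="Update",
    timeMode="None",
    initialState=init_df,  # new optional argument; omit it (or pass None) for the previous behavior
)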

76 changes: 76 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
@@ -1190,3 +1190,79 @@ def dump_stream(self, iterator, stream):
"""
result = [(b, t) for x in iterator for y, t in x for b in y]
super().dump_stream(result, stream)


class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
"""
Serializer used by the Python worker to evaluate the UDF for
:meth:`pyspark.sql.GroupedData.transformWithStateInPandas` when an initial state is provided.
Parameters
----------
Same as input parameters in TransformWithStateInPandasSerializer.
"""

def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
super(TransformWithStateInPandasInitStateSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch
)
self.init_key_offsets = None

def load_stream(self, stream):
import pyarrow as pa

def generate_data_batches(batches):
"""
Deserialize ArrowRecordBatches and return a generator of lists of pandas.Series.
The deserialization logic assumes that the Arrow RecordBatches are ordered so that
data chunks for the same grouping key appear sequentially.
See `TransformWithStateInPandasPythonInitialStateRunner` for the Arrow batch schema
sent from the JVM.
This function flattens the columns of the input rows and initial state rows and feeds
them into the data generator.
"""

def flatten_columns(cur_batch, col_name):
state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
state_field_names = [
state_column.type[i].name for i in range(state_column.type.num_fields)
]
state_field_arrays = [
state_column.field(i) for i in range(state_column.type.num_fields)
]
table_from_fields = pa.Table.from_arrays(
state_field_arrays, names=state_field_names
)
return table_from_fields

"""
The Arrow batch is written with the schema:
schema: StructType = new StructType()
.add("inputData", dataSchema)
.add("initState", initStateSchema)
We parse each batch into a tuple of (key, inputData, initState) and pass it into the
Python data generator. All rows in the same batch share the same grouping key.
"""
for batch in batches:
Contributor:
Maybe better to have a brief comment about how the batch is constructed, or some of its characteristics, or even where to read the code to understand the data structure. Personally I read this code before reading the part that builds the batch, and had to assume that a batch must only contain data from a single grouping key, otherwise it won't work.

flatten_state_table = flatten_columns(batch, "inputData")
data_pandas = [self.arrow_to_pandas(c) for c in flatten_state_table.itercolumns()]

flatten_init_table = flatten_columns(batch, "initState")
init_data_pandas = [
self.arrow_to_pandas(c) for c in flatten_init_table.itercolumns()
]
key_series = [data_pandas[o] for o in self.key_offsets]
init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]

if any(s.empty for s in key_series):
# If any key series is empty (no input rows for this grouping key), build batch_key from init_key_series
batch_key = tuple(s[0] for s in init_key_series)
else:
# If all key series are non-empty, build batch_key from key_series
batch_key = tuple(s[0] for s in key_series)
yield (batch_key, data_pandas, init_data_pandas)

_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)

for k, g in groupby(data_batches, key=lambda x: x[0]):
yield (k, g)
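
To make the flatten_columns step above concrete, here is a small standalone pyarrow sketch with made-up data (not taken from this PR) showing how a struct column such as "inputData" is exploded into a flat table of its child fields, mirroring the calls used in the serializer:

import pyarrow as pa

# Two-row batch with two struct columns, shaped like the ("inputData", "initState") schema.
batch = pa.record_batch(
    [
        pa.array([{"id": "a", "value": 1}, {"id": "a", "value": 2}]),
        pa.array([{"id": "a", "count": 5}, None]),
    ],
    names=["inputData", "initState"],
)

state_column = batch.column(batch.schema.get_field_index("inputData"))
names = [state_column.type[i].name for i in range(state_column.type.num_fields)]
arrays = [state_column.field(i) for i in range(state_column.type.num_fields)]
flat = pa.Table.from_arrays(arrays, names=names)
print(flat.to_pandas())  # two rows with columns: id, value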
7 changes: 7 additions & 0 deletions python/pyspark/sql/streaming/stateful_processor.py
@@ -426,3 +426,10 @@ def close(self) -> None:
operations.
"""
...

def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None:
"""
Optional to implement. Acts as a no-op if not defined or if there is no initial state input.
Invoked only in the first batch, allowing users to process their initial state.
"""
pass
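
A hedged sketch (not part of this diff) of a processor that overrides the new hook; the class, the column names, and the value-state calls (getValueState, exists-free update, .at lookup) follow the shape of the existing transformWithStateInPandas examples and are assumptions, not code from this PR:

import pandas as pd
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, IntegerType

class CountWithInitProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        schema = StructType([StructField("count", IntegerType(), True)])
        self.count_state = handle.getValueState("count", schema)

    def handleInitialState(self, key, initialState) -> None:
        # Invoked once per grouping key in the first batch; initialState is the
        # pandas DataFrame of initial-state rows for that key (assumed here to
        # carry a "count" column).
        self.count_state.update((int(initialState.at[0, "count"]),))

    def handleInputRows(self, key, rows):
        # Pass-through for brevity; a real processor would read and update state here.
        for pdf in rows:
            yield pdf

    def close(self) -> None:
        pass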