[SPARK-49601][SS][PYTHON] Support Initial State Handling for TransformWithStateInPandas #48005
Conversation
Force-pushed from 253e56d to 099d827.
sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto (outdated, resolved)
...main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala (outdated, resolved)
case _ =>
  throw new IllegalArgumentException("Invalid method call")
}
}

private def handleStatefulProcessorUtilRequest(message: UtilsCallCommand): Unit = {
Should we add some Scala unit tests for these two new APIs?
...main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala (outdated, resolved)
...main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala (outdated, resolved)
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py (outdated, resolved)
yield pd.DataFrame({"id": key, "value": str(accumulated_value)})

def handleInitialState(self, key, initialState) -> None:
    initVal = initialState.at[0, "initVal"]
Can we add verifications on the initVal here?
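A hedged sketch of the verification the reviewer asks for. Here `state_row` is a plain dict standing in for the single-row pandas DataFrame that `handleInitialState` receives; the function name and error messages are illustrative, not Spark APIs:

```python
def validated_init_val(state_row: dict) -> int:
    # Illustrative check: fail fast if the initial-state row is malformed,
    # instead of silently seeding state with a bad value.
    if "initVal" not in state_row:
        raise ValueError("initial state row is missing 'initVal'")
    init_val = state_row["initVal"]
    if not isinstance(init_val, int):
        raise TypeError(f"expected int initVal, got {type(init_val).__name__}")
    return init_val
```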
@@ -402,6 +404,9 @@ def transformWithStateInPandas(
    The output mode of the stateful processor.
timeMode : str
    The time mode semantics of the stateful processor for timers and TTL.
initialState: "GroupedData"
Let's use something like below to represent the actual type.
:class:`pyspark.sql.types.DataType`
...main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala (outdated, resolved)
...ain/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala (outdated, resolved)
) -> Iterator["PandasDataFrameLike"]:
    handle = StatefulProcessorHandle(statefulProcessorApiClient)

    if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
There's something not very clear to me here, could you help me understand more? We only call `handleInitialState` when the handle state is CREATED, but after we process the initial state of the first grouping key, we update the state to INITIALIZED. Wouldn't that skip the initial state for all other grouping keys?
If my understanding is correct, we should move the `handleInitialState` call outside the handle-state check and do it after the `init` call.
You are correct. I moved the code block out and ran a local test with the partition number set to "1" to confirm the implementation is correct.
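The fix discussed above can be sketched as below (stdlib-only, hypothetical names): `init` stays guarded by the handle state so it runs exactly once, while `handleInitialState` now runs for every grouping key that carries initial state:

```python
CREATED, INITIALIZED = "CREATED", "INITIALIZED"

class RecordingProcessor:
    """Toy processor that records which hooks were invoked."""
    def __init__(self):
        self.calls = []
    def init(self, handle):
        self.calls.append("init")
    def handleInitialState(self, key, initial_state):
        self.calls.append(("initial_state", key))
    def handleInputRows(self, key, rows):
        self.calls.append(("input_rows", key))
        return rows

def process_keys(processor, groups):
    # groups: {key: (input_rows, initial_state_or_None)}
    handle_state = CREATED
    out = []
    for key, (rows, init_state) in groups.items():
        if handle_state == CREATED:      # init runs exactly once
            processor.init(None)
            handle_state = INITIALIZED
        if init_state is not None:       # ...but initial state is seeded per key
            processor.handleInitialState(key, init_state)
        out.extend(processor.handleInputRows(key, rows))
    return out
```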
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
# for non first batch, initialStates will be None
For a non-first batch, would initialStates be None or empty?
Added the above in the comments with other input combinations.
It would be None. This is a bit hacky: we pick the Python eval type purely based on whether the input `initialState` dataframe is None or not. For a non-empty input initial state and a non-first batch, we will still eval the UDF as `transformWithStateWithInitStateUDF` here. Since the JVM starts a PythonRunner with eval type `transformWithStateUDF` for non-first batches, `initialStates` falls back to its default positional value: `initialStates: Iterator["PandasDataFrameLike"] = None`.
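The dispatch described above can be sketched as follows (hypothetical helper returning the UDF name as a string; the real eval-type constants live in `PythonEvalType`):

```python
def pick_udf(initial_state) -> str:
    # The eval type is chosen purely from whether an initialState DataFrame
    # was provided at plan time; for non-first batches the JVM starts the
    # plain runner and initialStates arrives as None in the Python worker.
    if initial_state is not None:
        return "transformWithStateWithInitStateUDF"
    return "transformWithStateUDF"
```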
inputRows: Iterator["PandasDataFrameLike"],
# for non first batch, initialStates will be None
initialStates: Iterator["PandasDataFrameLike"] = None
) -> Iterator["PandasDataFrameLike"]:
Can we add some comments on the possible input combinations that we need to handle in this udf, for people to understand easier? IIUC there should be 3 cases:
- Both `inputRows` and `initialStates` contain data. This would only happen in the first batch, when the associated grouping key has both input data and initial state.
- Only `inputRows` contains data. This could happen when either the grouping key doesn't have any initial state to process, or it's a non-first batch.
- Only `initialStates` contains data. This could happen when the grouping key doesn't have any associated input data but has initial state to process.
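The three cases the reviewer lists can be sketched as a small classifier (stdlib-only; in the real UDF both arguments are iterators of pandas DataFrames, and the names here are illustrative):

```python
def classify(input_rows: list, initial_states) -> str:
    # initial_states is None outside the first batch; see discussion below.
    has_input = len(input_rows) > 0
    has_init = initial_states is not None and len(initial_states) > 0
    if has_input and has_init:
        return "first batch: key has both input data and initial state"
    if has_input:
        return "key has input data only (no initial state, or non-first batch)"
    if has_init:
        return "first batch: key has initial state only"
    return "nothing to do"
```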
Added the above in the comment.
LGTM overall, just some nits.
seen_init_state_on_key = False
for cur_initial_state in initialStates:
    if seen_init_state_on_key:
        raise Exception(f"TransformWithStateWithInitState: Cannot have more "
Nit: let's include the TODO for classifying the errors here.
I am removing this check as we'll allow multiple value rows for the same grouping key as part of the integration of supporting initial state handling with state reader source (for flattened list/map state, there will be multiple value rows with the same grouping key in the output dataframe).
Force-pushed from 8e90c2e to 45459d9.
First pass
@@ -409,6 +410,9 @@ def transformWithStateInPandas(
    The output mode of the stateful processor.
timeMode : str
    The time mode semantics of the stateful processor for timers and TTL.
initialState : :class:`pyspark.sql.GroupedData`
    Optional. The grouped dataframe on given grouping key as initial states used for initialization
nit: Now the method docs for the Scala version and the PySpark version have diverged, not only in the type (which is expected) but also in the description itself. For example, here is the explanation of `initialState` in the Scala API:

User provided initial state that will be used to initiate state for the query in the first batch.

Probably better to revisit both API docs at some point and sync the two. Before doing that, I think the part "on given grouping key" is redundant and causes confusion. We should have checked the compatibility of the grouping key between the two groups (the current Dataset, and the Dataset for initialState), right? If so, we could just remove it.
""" | ||
UDF for TWS operator with non-empty initial states. Possible input combinations | ||
of inputRows and initialStates iterator: | ||
- Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows |
nit: "both input rows and initial states contains the grouping key" sounds redundant, since we already call out "for the given key". inputRows and initialStates are expected to be a flattened Dataset (not a grouped one), right? Their grouping key is the given key.
ditto for all others
Good points! Removed redundant words.
of inputRows and initialStates iterator:
- Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
  and initial states contains the grouping key, both input rows and initial states contains data.
- `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
nit: "`InitialStates` is non-empty, while `initialStates` is empty" — you may want to change either one.
  initial states contains the grouping key and data, and it is first batch.
- `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
  contains the grouping key and data, and it is first batch.
- `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`
This represents the difference between an empty Dataset (or iterator) and None, right? Just to make clear.
Yes, an empty Dataset is different from None. When we are in a non-first batch, `initialStates` will be None.
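A tiny sketch of that distinction, using the naming above (illustrative helper, not a Spark API): an empty iterator means a first-batch key without initial state, while None means the initial-state channel does not exist at all because this is not the first batch:

```python
def initial_state_kind(initial_states) -> str:
    # None and "empty" are deliberately distinguished here.
    if initial_states is None:
        return "non-first batch: no initial-state channel"
    if len(list(initial_states)) == 0:
        return "first batch: key has no initial state"
    return "first batch: key has initial state"
```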
# only process initial state if first batch
is_first_batch = statefulProcessorApiClient.is_first_batch()
if is_first_batch and initialStates is not None:
I'd expect caller to handle this; providing initialStates for non-first batch is already adding unnecessary overhead and ideally caller should provide None for non-first batch. I'm OK to double check here for safety purpose, but maybe I'd do opposite, assert that (!is_first_batch and initialStates is None) is True.
Yeah, we are only making an API call for safety purposes, and it introduces a small overhead. I am removing the check entirely; as you commented below, the API itself is a bit confusing.
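The reviewer's suggested invariant can be sketched as a cheap assert with no extra API round-trip (names illustrative):

```python
def check_init_state_invariant(is_first_batch: bool, initial_states) -> None:
    # For any batch after the first, the caller must pass initialStates=None;
    # anything else indicates a wiring bug upstream.
    assert is_first_batch or initial_states is None, (
        "initialStates must be None for non-first batches"
    )
```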
funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
with PythonArrowInput[GroupedInType] {
ditto
eventTimeWatermarkForEviction: Option[Long],
hasInitialState: Boolean)
extends BasePythonRunner[I, ColumnarBatch](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
with PythonArrowInput[I]
ditto for all `with` lines
writer: ArrowStreamWriter,
dataOut: DataOutputStream,
inputIterator:
    Iterator[GroupedInType]): Boolean = {
nit: shift it one line up (any reason it's placed on the next line?)
If the combined line exceeds 100 chars, `: Boolean = {` should only be on this line, with 2 spaces shifted left from the parameters.
)
return table_from_fields

for batch in batches:
Maybe better to have a brief comment about how the batch is constructed or some of its characteristics, or even where to read the code to understand the data structure. Personally I read this code before reading the part that builds the batch, and had to assume that a batch must only contain data from a single grouping key, otherwise it won't work.
@@ -536,6 +536,108 @@ def check_results(batch_df, batch_id):
    EventTimeStatefulProcessor(), check_results
)

def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, check_results):
    input_path = tempfile.mkdtemp()
    self._prepare_test_resource1(input_path)
I see you are covering both cases in this test, which is great!
- grouping key in input, but not in initial state (1)
- grouping key in initial state, but not in input (3)
https://github.com/jingz-db/spark/actions/runs/11620673481/job/32364544070
SimpleStatefulProcessorWithInitialState(), check_results
)

def _test_transform_with_state_non_contiguous_grouping_cols(
shall we have the same test (non-contiguous grouping keys) for the path of initial state for completeness sake?
Second pass, I added a couple comments to address. Looks good to me otherwise.
@@ -567,7 +568,9 @@ class PythonEvalType:
SQL_GROUPED_MAP_ARROW_UDF: "ArrowGroupedMapUDFType" = 209
SQL_COGROUPED_MAP_ARROW_UDF: "ArrowCogroupedMapUDFType" = 210
SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: "PandasGroupedMapUDFTransformWithStateType" = 211

SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF: "PandasGroupedMapUDFTransformWithStateInitStateType" = (  # noqa: E501
Had to add the `# noqa` here, else we won't pass the `./dev/lint-python` or flake8 check.
+1
Thanks! Merging to master.
What changes were proposed in this pull request?
This PR adds support for users to provide a DataFrame that can be used to instantiate state for the query in the first batch, for arbitrary state API v2 in Python.

The Scala PR for supporting initial state is here: #45467

We propose to create a new PythonRunner that handles initial state specifically for TransformWithStateInPandas. On the JVM side, we coGroup input rows and initial state rows on the same grouping key. Then we create a new row that contains one row from the input rows iterator and one row from the initial state iterator, and send the new grouped row to Py4J. Inside the Python worker, we deserialize the grouped row into input rows and initial state rows separately and feed those into `handleInitialState` and `handleInputRows`.

We will launch a Python worker for each partition that has non-empty rows in either the input rows or the initial states. This guarantees that all keys in the initial state will be processed, even if they do not appear in the first batch or do not lie in the same partition as keys in the first batch.
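The cogrouping described above can be sketched in plain Python (the real implementation cogroups on the JVM side and ships Arrow batches to the worker; this stand-in just shows the key-pairing behavior):

```python
from collections import defaultdict

def cogroup_by_key(input_rows, init_rows):
    # Pair input rows and initial-state rows by grouping key, so that every
    # key appearing on EITHER side yields exactly one group for the worker.
    grouped = defaultdict(lambda: ([], []))
    for key, row in input_rows:
        grouped[key][0].append(row)
    for key, row in init_rows:
        grouped[key][1].append(row)
    return dict(grouped)
```

Note that a key present only in the initial state still produces a group (with an empty input-rows side), which is what guarantees such keys get processed in the first batch.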
Why are the changes needed?
We need to keep the Python API in sync with Scala, which already supports initial state handling.
Does this PR introduce any user-facing change?
Yes.
This PR introduces a new API in the `StatefulProcessor` which allows users to define their own UDF for processing initial state. Implementing this function is optional; if not defined, it acts as a no-op.
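Roughly, the new hook has this shape (a simplified sketch, not the exact PySpark signatures): the base class provides a no-op default, so existing processors keep working unchanged, and overriding it lets a processor seed state from the initial-state rows:

```python
class StatefulProcessorSketch:
    def handleInitialState(self, key, initialState) -> None:
        # Optional hook: override to seed state from the initial-state rows.
        pass  # default is a no-op

class SeedingProcessor(StatefulProcessorSketch):
    """Toy override that records the initial state it was handed per key."""
    def __init__(self):
        self.seeded = {}
    def handleInitialState(self, key, initialState) -> None:
        self.seeded[key] = initialState
```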
How was this patch tested?
Unit tests & integration tests.
Was this patch authored or co-authored using generative AI tooling?
No.