[SPARK-49601][SS][PYTHON] Support Initial State Handling for TransformWithStateInPandas #48005


Closed
jingz-db wants to merge 12 commits into master from python-init-state-impl

Conversation

Contributor

@jingz-db commented Sep 5, 2024

What changes were proposed in this pull request?

This PR adds support for users to provide a DataFrame that can be used to instantiate state for the query in the first batch, for the arbitrary state API v2 in Python.
The Scala PR for supporting initial state is here: #45467

We propose to create a new PythonRunner that handles initial state specifically for TransformWithStateInPandas. On the JVM side, we cogroup the input rows and the initial state rows on the same grouping key. We then create a new row that contains one row from the input rows iterator and one row from the initial state iterator, and send the grouped row to the Python worker. Inside the Python worker, we deserialize the grouped row back into input rows and initial state rows separately, and pass them into handleInitialState and handleInputRows.
We launch a Python worker for each partition that has non-empty rows in either the input rows or the initial state. This guarantees that all keys in the initial state are processed, even if they do not appear in the first batch or do not lie in the same partition as the keys in the first batch.
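For orientation, below is a hedged sketch of the user-facing call shape; df, init_df, the column names, and MyProcessorWithInitialState are illustrative placeholders (a sketch of the processor appears in the API section below), not code from this PR:

# Hedged sketch: assume df is the streaming input DataFrame and init_df is a
# batch DataFrame holding the initial state; both are grouped on "id".
initial_state = init_df.groupBy("id")

result = df.groupBy("id").transformWithStateInPandas(
    statefulProcessor=MyProcessorWithInitialState(),
    outputStructType="id string, value string",
    outputMode="Update",
    timeMode="None",
    initialState=initial_state,
)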

Why are the changes needed?

We need to keep the Python API in parity with the Scala API, which already supports initial state handling.

Does this PR introduce any user-facing change?

Yes.
This PR introduces a new API in the StatefulProcessor which allows users to define their own logic for processing the initial state:

def handleInitialState(
    self, key: Any, initialState: "PandasDataFrameLike"
) -> None:

Implementing this function is optional. If it is not defined, it acts as a no-op.
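For illustration, here is a minimal, hedged sketch of a processor that overrides handleInitialState. The state name, the schema, and the "initVal" column are assumptions for this example and may differ from the actual tests in this PR:

import pandas as pd
from typing import Any, Iterator
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StructField, StructType


class MyProcessorWithInitialState(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        # Register a value state holding a single integer column.
        state_schema = StructType([StructField("value", IntegerType(), True)])
        self.value_state = handle.getValueState("value_state", state_schema)

    def handleInitialState(self, key: Any, initialState: pd.DataFrame) -> None:
        # Seed the registered state from the user-provided initial state rows;
        # "initVal" is an assumed column name in the initial state DataFrame.
        init_val = int(initialState.at[0, "initVal"])
        self.value_state.update((init_val,))

    def handleInputRows(
        self, key: Any, rows: Iterator[pd.DataFrame]
    ) -> Iterator[pd.DataFrame]:
        # Count the incoming rows for this key and emit a single result row.
        count = sum(len(pdf) for pdf in rows)
        yield pd.DataFrame({"id": [key[0]], "value": [str(count)]})

    def close(self) -> None:
        pass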

How was this patch tested?

Unit tests & integration tests.

Was this patch authored or co-authored using generative AI tooling?

No.

@jingz-db jingz-db changed the title [SS][PYTHON] Support Initial State for TransformWithStateInPandas [SS][PYTHON] Support Initial State Handling for TransformWithStateInPandas Sep 5, 2024
@jingz-db jingz-db marked this pull request as ready for review September 9, 2024 16:48
@github-actions github-actions bot added the DOCS label Sep 9, 2024
@jingz-db jingz-db changed the title [SS][PYTHON] Support Initial State Handling for TransformWithStateInPandas [SPARK-49601][SS][PYTHON] Support Initial State Handling for TransformWithStateInPandas Sep 11, 2024
@jingz-db jingz-db force-pushed the python-init-state-impl branch from 253e56d to 099d827 on September 12, 2024 19:01
case _ =>
throw new IllegalArgumentException("Invalid method call")
}
}

private def handleStatefulProcessorUtilRequest(message: UtilsCallCommand): Unit = {
Contributor

Should we add some Scala unit tests for these two new APIs?

yield pd.DataFrame({"id": key, "value": str(accumulated_value)})

def handleInitialState(self, key, initialState) -> None:
    initVal = initialState.at[0, "initVal"]
Contributor

Can we add verification of the initVal here?

@github-actions github-actions bot added the CORE label Oct 10, 2024
@jingz-db jingz-db requested a review from bogao007 October 11, 2024 09:27
@@ -402,6 +404,9 @@ def transformWithStateInPandas(
The output mode of the stateful processor.
timeMode : str
The time mode semantics of the stateful processor for timers and TTL.
initialState: "GroupedData"
Contributor

Let's use something like the following to represent the actual type.

:class:`pyspark.sql.types.DataType`

) -> Iterator["PandasDataFrameLike"]:
handle = StatefulProcessorHandle(statefulProcessorApiClient)

if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
Contributor

There's something not very clear to me here; could you help me understand?

We only call handleInitialState when the handle state is CREATED, but after we process the initial state of the first grouping key, we update the handle state to INITIALIZED. Wouldn't that skip the initial state for the other grouping keys?

Contributor

If my understanding is correct, we should move handleInitialState outside the handle state check and do it after the init call.

Contributor Author

You are correct. I moved the code block out and ran a local test with the partition number set to "1" to confirm the implementation is correct.

statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
# for non first batch, initialStates will be None
Contributor

For a non-first batch, would initialStates be None or empty?

Contributor Author

Added the above to the comments along with the other input combinations.

Contributor Author

It would be None. This is a bit hacky: we pass in the Python eval type purely based on whether the input initialState dataframe is None or not. For a non-empty input initial state and a non-first batch, we will still evaluate the UDF as transformWithStateWithInitStateUDF here. Since the JVM will start a PythonRunner with the transformWithStateUDF eval type for non-first batches, we will get initialStates as None because that is the positional default value: initialStates: Iterator["PandasDataFrameLike"] = None
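For clarity, the dispatch described above could be pictured roughly as follows; the eval-type constants match the PythonEvalType entries added later in this PR, while the surrounding code is illustrative only:

# Illustrative only: pick the eval type purely from whether the user passed
# an initialState GroupedData when defining the query.
if initialState is None:
    udf_eval_type = PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
else:
    udf_eval_type = PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF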

inputRows: Iterator["PandasDataFrameLike"],
# for non first batch, initialStates will be None
initialStates: Iterator["PandasDataFrameLike"] = None
) -> Iterator["PandasDataFrameLike"]:
Contributor

Can we add some comments on the possible input combinations that we need to handle in this UDF, so it's easier for people to understand? IIUC there should be 3 cases (see the sketch after this list):

  • Both inputRows and initialStates contain data. This only happens in the first batch, when the associated grouping key has both input data and initial state.
  • Only inputRows contains data. This could happen when either the grouping key doesn't have any initial state to process, or it's not the first batch.
  • Only initialStates contains data. This could happen when the grouping key doesn't have any associated input data but does have initial state to process.
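A hedged, simplified sketch of how a UDF could branch over these cases is shown below; the real transformWithStateWithInitStateUDF in this PR also manages handle state and serialization, so treat this as illustrative pseudologic only:

def init_state_udf_sketch(stateful_processor, key, input_rows, initial_states=None):
    # Cases 1 and 3: initial state rows exist for this key (first batch only).
    if initial_states is not None:
        for init_pdf in initial_states:
            if not init_pdf.empty:
                stateful_processor.handleInitialState(key, init_pdf)
    # Cases 1 and 2: process whatever input rows arrived for this key; the
    # iterator may be empty when only initial state exists (case 3).
    return stateful_processor.handleInputRows(key, input_rows)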

Contributor Author

Added the above in the comment.

@jingz-db jingz-db requested a review from bogao007 October 14, 2024 18:24
Contributor

@bogao007 left a comment

LGTM overall, just some nits.

seen_init_state_on_key = False
for cur_initial_state in initialStates:
    if seen_init_state_on_key:
        raise Exception(f"TransformWithStateWithInitState: Cannot have more "
Contributor

Nit: let's include the TODO for classifying the errors here.

Contributor Author

I am removing this check because we'll allow multiple value rows for the same grouping key as part of integrating initial state handling with the state reader source (for flattened list/map state, there will be multiple value rows with the same grouping key in the output dataframe).

@jingz-db jingz-db force-pushed the python-init-state-impl branch from 8e90c2e to 45459d9 on October 31, 2024 22:06
Contributor

@HeartSaVioR left a comment

First pass

@@ -409,6 +410,9 @@ def transformWithStateInPandas(
The output mode of the stateful processor.
timeMode : str
The time mode semantics of the stateful processor for timers and TTL.
initialState : :class:`pyspark.sql.GroupedData`
Optional. The grouped dataframe on given grouping key as initial states used for initialization
Contributor

nit: Now the method docs for the Scala version and the PySpark version have diverged, not only in the type (which is expected) but also in the description itself.

For example, here is the explanation of initialState in the Scala API:

User provided initial state that will be used to initiate state for the query in the first batch.

Probably better to revisit both API docs at some point and sync the two.

Before doing that, I think the part "on given grouping key" is redundant and causes confusion. We should have checked the compatibility of the grouping key between the two groups (the current Dataset and the Dataset for initialState), right? If so, we could just remove it.

"""
UDF for TWS operator with non-empty initial states. Possible input combinations
of inputRows and initialStates iterator:
- Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
Contributor

nit: "both input rows and initial states contains the grouping key" sounds redundant since we already call out "for the given key". inputRows and initialStates are expected to be flattened Datasets (not grouped ones), right? Their grouping key is the given key.

Contributor

ditto for all others

Contributor Author

Good points! Removed redundant words.

of inputRows and initialStates iterator:
- Both `inputRows` and `initialStates` are non-empty: for the given key, both input rows
and initial states contains the grouping key, both input rows and initial states contains data.
- `InitialStates` is non-empty, while `initialStates` is empty. For the given key, only
Contributor

nit: InitialStates is non-empty, while initialStates is empty.

You may want to change either one.

initial states contains the grouping key and data, and it is first batch.
- `initialStates` is empty, while `inputRows` is not empty. For the given grouping key, only inputRows
contains the grouping key and data, and it is first batch.
- `initialStates` is None, while `inputRows` is not empty. This is not first batch. `initialStates`
Contributor

This represents the difference between an empty Dataset (or iterator) and None, right? Just to make sure it's clear.

Contributor Author

Yes, an empty Dataset is different from None. When we are in a non-first batch, initialStates will be None.


# only process initial state if first batch
is_first_batch = statefulProcessorApiClient.is_first_batch()
if is_first_batch and initialStates is not None:
Contributor

I'd expect the caller to handle this; providing initialStates for a non-first batch already adds unnecessary overhead, and ideally the caller should provide None for non-first batches. I'm OK with double-checking here for safety, but maybe I'd do the opposite and assert that (!is_first_batch and initialStates is None) is True.

Contributor Author

Yeah, we are only making an API call for safety purposes and it introduces a small overhead. I am removing the check entirely; as you commented below, the API itself is a bit confusing.

funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction, hasInitialState)
with PythonArrowInput[GroupedInType] {
Contributor

ditto

eventTimeWatermarkForEviction: Option[Long],
hasInitialState: Boolean)
extends BasePythonRunner[I, ColumnarBatch](funcs.map(_._1), evalType, argOffsets, jobArtifactUUID)
with PythonArrowInput[I]
Contributor

ditto for all with lines

writer: ArrowStreamWriter,
dataOut: DataOutputStream,
inputIterator:
Iterator[GroupedInType]): Boolean = {
Contributor

nit: shift this up one line (any reason it's placed on the next line?)

Contributor

If the combined line exceeds 100 chars, : Boolean = { should be the only thing on this line, shifted 2 spaces left from the parameters.

)
return table_from_fields

for batch in batches:
Contributor

Maybe better to have a brief comment about how the batch is constructed or some of its characteristics, or even where to read the code to understand the data structure. Personally, I read this code before reading the batch-building part and had to assume that a batch must only have data from a single grouping key, otherwise it won't work.

@@ -536,6 +536,108 @@ def check_results(batch_df, batch_id):
EventTimeStatefulProcessor(), check_results
)

def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, check_results):
    input_path = tempfile.mkdtemp()
    self._prepare_test_resource1(input_path)
Contributor

I see you are covering both cases in this test, which is great!

  • grouping key in input, but not in initial state (1)
  • grouping key in initial state, but not in input (3)
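For illustration, a first-batch assertion covering both cases could look roughly like the sketch below; the keys and expected values are hypothetical, not the actual test data:

from pyspark.sql import Row

def check_results(batch_df, batch_id):
    if batch_id == 0:
        # Hypothetical expectation: key "0" appears only in the input, while
        # key "3" appears only in the initial state; both must be emitted.
        assert set(batch_df.sort("id").collect()) == {
            Row(id="0", value="123"),
            Row(id="3", value="init"),
        }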


@jingz-db jingz-db requested a review from HeartSaVioR November 4, 2024 23:07
SimpleStatefulProcessorWithInitialState(), check_results
)

def _test_transform_with_state_non_contiguous_grouping_cols(
Contributor

@HeartSaVioR Nov 5, 2024

Shall we have the same test (non-contiguous grouping keys) for the initial state path, for completeness' sake?

Contributor

@HeartSaVioR left a comment

Second pass; I added a couple of comments to address. Looks good to me otherwise.

@@ -567,7 +568,9 @@ class PythonEvalType:
    SQL_GROUPED_MAP_ARROW_UDF: "ArrowGroupedMapUDFType" = 209
    SQL_COGROUPED_MAP_ARROW_UDF: "ArrowCogroupedMapUDFType" = 210
    SQL_TRANSFORM_WITH_STATE_PANDAS_UDF: "PandasGroupedMapUDFTransformWithStateType" = 211

    SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF: "PandasGroupedMapUDFTransformWithStateInitStateType" = (  # noqa: E501
Contributor Author

Had to add the # noqa here, otherwise we won't pass ./dev/lint-python or the flake8 check.

@jingz-db jingz-db requested a review from HeartSaVioR November 5, 2024 22:47
Contributor

@HeartSaVioR left a comment

+1

Contributor

@HeartSaVioR

Thanks! Merging to master.
