[SPARK-49601][SS][PYTHON] Support Initial State Handling for TransformWithStateInPandas #48005
@@ -16,7 +16,7 @@
 #
 import itertools
 import sys
-from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast
+from typing import Any, Iterator, List, Optional, Union, TYPE_CHECKING, cast
 import warnings

 from pyspark.errors import PySparkTypeError

@@ -373,6 +373,7 @@ def transformWithStateInPandas(
         outputStructType: Union[StructType, str],
         outputMode: str,
         timeMode: str,
+        initialState: Optional["GroupedData"] = None,
     ) -> DataFrame:
         """
         Invokes methods defined in the stateful processor used in arbitrary state API v2. It

@@ -409,6 +410,9 @@ def transformWithStateInPandas(
             The output mode of the stateful processor.
         timeMode : str
             The time mode semantics of the stateful processor for timers and TTL.
+        initialState : :class:`pyspark.sql.GroupedData`
+            Optional. The grouped DataFrame of initial states, used to initialize the
+            state variables in the first batch.

         Examples
         --------

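For orientation, here is a minimal usage sketch of the new argument. It is not part of this diff: the processor class `SimpleCountProcessor` (sketched later in this section), the streaming DataFrame `input_stream_df`, the column names, and the mode strings are assumptions for illustration.

```python
# Hypothetical usage of the new initialState argument (names are illustrative).
init_df = spark.createDataFrame([(1, 100), (2, 200)], ["id", "seed"])

out_df = (
    input_stream_df.groupBy("id")  # assumed streaming DataFrame with an "id" column
    .transformWithStateInPandas(
        statefulProcessor=SimpleCountProcessor(),  # assumed StatefulProcessor subclass
        outputStructType="id LONG, count LONG",
        outputMode="Update",
        timeMode="None",
        initialState=init_df.groupBy("id"),  # the new optional argument
    )
)
```
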
@@ -493,22 +497,17 @@ def transformWithStateInPandas(
         from pyspark.sql.functions import pandas_udf

         assert isinstance(self, GroupedData)
+        if initialState is not None:
+            assert isinstance(initialState, GroupedData)
-        if isinstance(outputStructType, str):
-            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

-        def transformWithStateUDF(
+        def handle_data_with_timers(
             statefulProcessorApiClient: StatefulProcessorApiClient,
             key: Any,
             inputRows: Iterator["PandasDataFrameLike"],
         ) -> Iterator["PandasDataFrameLike"]:
-            handle = StatefulProcessorHandle(statefulProcessorApiClient)
-
-            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
-                statefulProcessor.init(handle)
-                statefulProcessorApiClient.set_handle_state(
-                    StatefulProcessorHandleState.INITIALIZED
-                )
-
             statefulProcessorApiClient.set_implicit_key(key)

             if timeMode != "none":
                 batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
                 watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()

@@ -551,25 +550,102 @@ def transformWithStateUDF(
             # TODO(SPARK-49603) set the handle state in the lazily initialized iterator

             result = itertools.chain(*result_iter_list)
             return result

+        def transformWithStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+        ) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )
+
+            result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            return result
+
+        def transformWithStateWithInitStateUDF(
+            statefulProcessorApiClient: StatefulProcessorApiClient,
+            key: Any,
+            inputRows: Iterator["PandasDataFrameLike"],
+            initialStates: Optional[Iterator["PandasDataFrameLike"]] = None,
+        ) -> Iterator["PandasDataFrameLike"]:
+            """
+            UDF for the TWS operator with non-empty initial states. Possible input
+            combinations of the inputRows and initialStates iterators:
+            - Both `inputRows` and `initialStates` are non-empty. Both the input rows
+              and the initial states contain the grouping key and data.
+            - `initialStates` is non-empty, while `inputRows` is empty. Only the initial
+              states contain the grouping key and data, and it is the first batch.
+            - `initialStates` is empty, while `inputRows` is non-empty. Only the input
+              rows contain the grouping key and data, and it is the first batch.
+            - `initialStates` is None, while `inputRows` is non-empty. This is not the
+              first batch; `initialStates` takes its default positional value, None.
+            """
+            handle = StatefulProcessorHandle(statefulProcessorApiClient)
+
+            if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
+                statefulProcessor.init(handle)
+                statefulProcessorApiClient.set_handle_state(
+                    StatefulProcessorHandleState.INITIALIZED
+                )

[Review thread]
Reviewer: There's something not very clear to me here, could you help me understand more? We only call [...]
Reviewer: If my understanding is correct, we should move the [...]
Author: You are correct. I also moved the code block out and ran a local test with the partition number set to "1" to confirm the implementation is correct.

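To make the semantics settled in this thread concrete, here is a tiny self-contained model (plain Python, no PySpark) of the CREATED/INITIALIZED guard: `init` runs only for the first grouping key a worker processes, because the handle state flips after the first call.

```python
# Minimal model of the once-per-partition init guard (illustrative only).
class FakeApiClient:
    def __init__(self):
        self.handle_state = "CREATED"

    def set_handle_state(self, state):
        self.handle_state = state


def process_key(client, key):
    if client.handle_state == "CREATED":
        print("init() called")  # runs once per partition, not once per key
        client.set_handle_state("INITIALIZED")
    print("processing key", key)


client = FakeApiClient()
for key in ["a", "b", "c"]:
    process_key(client, key)
# init() called / processing key a / processing key b / processing key c
```
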
+            # Only process the initial state if this is the first batch and the
+            # initial state is not None.
+            if initialStates is not None:
+                for cur_initial_state in initialStates:
+                    statefulProcessorApiClient.set_implicit_key(key)
+                    # TODO(SPARK-50194) integrate with the new timer API for initial state
+                    statefulProcessor.handleInitialState(key, cur_initial_state)
+            # If we don't have input rows for the given key but only have the initial
+            # state for the grouping key, the inputRows iterator could be empty.
+            input_rows_empty = False
+            try:
+                first = next(inputRows)
+            except StopIteration:
+                input_rows_empty = True
+            else:
+                inputRows = itertools.chain([first], inputRows)
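The try/except/else above is the standard peek-and-rechain idiom: it tests an iterator for emptiness without dropping its first element. A standalone illustration:

```python
import itertools


def peek_is_empty(it):
    """Return (is_empty, iterator), preserving the first element."""
    try:
        first = next(it)
    except StopIteration:
        return True, it
    return False, itertools.chain([first], it)


empty, it = peek_is_empty(iter([]))
print(empty, list(it))  # True []
empty, it = peek_is_empty(iter([1, 2]))
print(empty, list(it))  # False [1, 2]
```
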
+            if not input_rows_empty:
+                result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+            else:
+                result = iter([])
+
+            return result

[Review thread]
Reviewer: Wait, isn't there a case where the inputRows iterator is empty but a timer is expected to expire? My understanding is that if you pass transformWithStateWithInitStateUDF as the udf, it will be used for all batches, right? Then how could this handle a grouping key in batch N that has no data but has a timer to expire?
Reviewer: If you don't have a test covering this scenario, please add it as well.
Author: You are also right about this; that is the scenario covered in this PR: #45780. I was planning to finish the integration together with the new timer API that Anish merged last week. I also left a TODO a few lines above. If that's OK with you, I will finish this portion in SPARK-50194.
Reviewer: It's OK to defer the fix, but let's add the TODO comment to the non-initial-state path as well. It also has to be changed.
Author: The timer integration for the non-initial-state path is already done here: https://github.com/apache/spark/pull/48005/files#diff-5862151bb5e9fe7a6b2d1978301c235504dcc6c1bbbd1f9745a204a3ba93146eR568.

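For reference, a sketch of the user-facing side of handleInitialState: a hypothetical processor that seeds a count from the initial state, matching the earlier usage sketch. The import path, the getValueState signature, and the handleInputRows signature are assumptions and may differ across Spark versions; treat this as an outline rather than the canonical API.

```python
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor  # import path may vary by version


class SimpleCountProcessor(StatefulProcessor):
    def init(self, handle):
        # Assumed API: a single-value state holding one long field.
        self.count_state = handle.getValueState("count", "count LONG")

    def handleInitialState(self, key, initialState: pd.DataFrame) -> None:
        # Invoked on the first batch for keys that have initial state rows;
        # "seed" is a hypothetical column in the initial state DataFrame.
        self.count_state.update((int(initialState.at[0, "seed"]),))

    def handleInputRows(self, key, rows):
        count = self.count_state.get()[0] if self.count_state.exists() else 0
        for pdf in rows:
            count += len(pdf)
        self.count_state.update((count,))
        yield pd.DataFrame({"id": [key[0]], "count": [count]})

    def close(self) -> None:
        pass
```
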
+        if isinstance(outputStructType, str):
+            outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
+
-        udf = pandas_udf(
-            transformWithStateUDF,  # type: ignore
-            returnType=outputStructType,
-            functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
-        )
         df = self._df

+        if initialState is None:
+            initial_state_java_obj = None
+            udf = pandas_udf(
+                transformWithStateUDF,  # type: ignore
+                returnType=outputStructType,
+                functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
+            )
+        else:
+            initial_state_java_obj = initialState._jgd
+            udf = pandas_udf(
+                transformWithStateWithInitStateUDF,  # type: ignore
+                returnType=outputStructType,
+                functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
+            )
         udf_column = udf(*[df[col] for col in df.columns])

         jdf = self._jgd.transformWithStateInPandas(
             udf_column._jc,
             self.session._jsparkSession.parseDataType(outputStructType.json()),
             outputMode,
             timeMode,
+            initial_state_java_obj,
         )
         return DataFrame(jdf, self.session)

@@ -1190,3 +1190,79 @@ def dump_stream(self, iterator, stream):
         """
         result = [(b, t) for x in iterator for y, t in x for b in y]
         super().dump_stream(result, stream)
+
+
+class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
+    """
+    Serializer used by the Python worker to evaluate the UDF for
+    :meth:`pyspark.sql.GroupedData.transformWithStateInPandas` with an initial state.
+
+    Parameters
+    ----------
+    Same as the input parameters of TransformWithStateInPandasSerializer.
+    """
+
+    def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
+        super(TransformWithStateInPandasInitStateSerializer, self).__init__(
+            timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch
+        )
+        self.init_key_offsets = None
+
+    def load_stream(self, stream):
+        import pyarrow as pa
+
+        def generate_data_batches(batches):
+            """
+            Deserialize ArrowRecordBatches and return a generator of lists of
+            pandas.Series. The deserialization logic assumes that the Arrow
+            RecordBatches are ordered so that data chunks for the same grouping key
+            appear sequentially. See `TransformWithStateInPandasPythonInitialStateRunner`
+            for the arrow batch schema sent from the JVM.
+            This function flattens the columns of the input rows and the initial state
+            rows and feeds them into the data generator.
+            """
+
+            def flatten_columns(cur_batch, col_name):
+                state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
+                state_field_names = [
+                    state_column.type[i].name for i in range(state_column.type.num_fields)
+                ]
+                state_field_arrays = [
+                    state_column.field(i) for i in range(state_column.type.num_fields)
+                ]
+                table_from_fields = pa.Table.from_arrays(
+                    state_field_arrays, names=state_field_names
+                )
+                return table_from_fields
+
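flatten_columns unpacks one struct column of the batch into a table of its child fields. A standalone pyarrow illustration of the same calls, using a toy schema (the real inputData/initState schemas come from the JVM):

```python
import pyarrow as pa

struct_type = pa.struct([("id", pa.int64()), ("val", pa.string())])
batch = pa.record_batch(
    [pa.array([{"id": 1, "val": "a"}, {"id": 1, "val": "b"}], type=struct_type)],
    names=["inputData"],
)

column = batch.column(batch.schema.get_field_index("inputData"))
names = [column.type[i].name for i in range(column.type.num_fields)]
arrays = [column.field(i) for i in range(column.type.num_fields)]
print(pa.Table.from_arrays(arrays, names=names).column_names)  # ['id', 'val']
```
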
""" | ||
The arrow batch is written in the schema: | ||
schema: StructType = new StructType() | ||
.add("inputData", dataSchema) | ||
.add("initState", initStateSchema) | ||
We'll parse batch into Tuples of (key, inputData, initState) and pass into the Python | ||
data generator. All rows in the same batch have the same grouping key. | ||
""" | ||
for batch in batches: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe better to have a brief comment about how the batch has constructed or some characteristic, or even where to read the code to understand the data structure. Personally I read this code before reading the part of building batch, and have to make an assumption that a batch must only have data from a single grouping key, otherwise it won't work. |
||
+                flatten_state_table = flatten_columns(batch, "inputData")
+                data_pandas = [
+                    self.arrow_to_pandas(c) for c in flatten_state_table.itercolumns()
+                ]
+
+                flatten_init_table = flatten_columns(batch, "initState")
+                init_data_pandas = [
+                    self.arrow_to_pandas(c) for c in flatten_init_table.itercolumns()
+                ]
+                key_series = [data_pandas[o] for o in self.key_offsets]
+                init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]
+
+                if any(s.empty for s in key_series):
+                    # If any key series is empty, this batch has no input rows for the
+                    # key, so take the batch key from init_key_series.
+                    batch_key = tuple(s[0] for s in init_key_series)
+                else:
+                    # Otherwise take the batch key from key_series.
+                    batch_key = tuple(s[0] for s in key_series)
+                yield (batch_key, data_pandas, init_data_pandas)
+
+        _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+        data_batches = generate_data_batches(_batches)
+
+        for k, g in groupby(data_batches, key=lambda x: x[0]):
+            yield (k, g)
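Note that itertools.groupby only merges *adjacent* items with equal keys, which is why the docstring above requires chunks for the same grouping key to arrive sequentially; if the JVM interleaved keys, one key would surface as several groups. A quick demonstration:

```python
from itertools import groupby

ordered = [("a", 1), ("a", 2), ("b", 3)]
print([(k, [v for _, v in g]) for k, g in groupby(ordered, key=lambda x: x[0])])
# [('a', [1, 2]), ('b', [3])]

interleaved = [("a", 1), ("b", 3), ("a", 2)]
print([(k, [v for _, v in g]) for k, g in groupby(interleaved, key=lambda x: x[0])])
# [('a', [1]), ('b', [3]), ('a', [2])]  -- 'a' shows up twice
```
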
[Review thread]
Reviewer: Can we add some comments on the possible input combinations that we need to handle in this udf, so it is easier for people to understand? IIUC there should be 3 cases:
- Both inputRows and initialStates contain data. This would only happen in the first batch, when the associated grouping key has both input data and an initial state.
- Only inputRows contains data. This could happen when either the grouping key doesn't have any initial state to process or it's not the first batch.
- Only initialStates contains data. This could happen when the grouping key doesn't have any associated input data but it has an initial state to process.
Reviewer: Add the above in the comment.