
[SPARK-52195][PYTHON][SS] Fix initial state column dropping issue for Python TWS #50926

Open · wants to merge 4 commits into master
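
For context, the user-facing pattern this change addresses looks roughly like the sketch below: transformWithStateInPandas is called with an initialState, and an extra projection is then applied on top of its output. Everything in the sketch (the processor, schemas, rate source, and checkpoint path) is illustrative only and is not taken verbatim from this PR.

import pandas as pd
from typing import Iterator

from pyspark.sql import SparkSession, functions as fn
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import LongType, StructField, StructType

spark = SparkSession.builder.getOrCreate()


class RunningSumProcessor(StatefulProcessor):
    """Keeps a running sum per key, seeded from the per-key initial state."""

    def init(self, handle: StatefulProcessorHandle) -> None:
        schema = StructType([StructField("total", LongType(), True)])
        self._total = handle.getValueState("total", schema)

    def handleInitialState(self, key, initialState, timerValues=None) -> None:
        # Seed the running sum from this key's row in the initial state DataFrame.
        self._total.update((int(initialState["initVal"].iloc[0]),))

    def handleInputRows(self, key, rows, timerValues=None) -> Iterator[pd.DataFrame]:
        total = self._total.get()[0] if self._total.exists() else 0
        for pdf in rows:
            total += int(pdf["value"].sum())
        self._total.update((total,))
        yield pd.DataFrame({"id": [str(key[0])], "value": [str(total)]})

    def close(self) -> None:
        pass


# Streaming input with columns (id, value); the rate source is just a stand-in.
input_df = (
    spark.readStream.format("rate").load()
    .selectExpr("value % 5 AS id", "CAST(value AS LONG) AS value")
)

# Batch DataFrame providing the per-key initial state.
init_df = spark.createDataFrame([(0, 789), (3, 987)], ["id", "initVal"])

tws_df = input_df.groupBy("id").transformWithStateInPandas(
    statefulProcessor=RunningSumProcessor(),
    outputStructType="id string, value string",
    outputMode="Update",
    timeMode="None",
    initialState=init_df.groupBy("id"),
)

# The extra transformation on top of the TWS output; before this fix, the column
# pruning triggered by such a projection could drop the initial state columns.
out_df = tws_df.select(
    fn.col("id").alias("key"),
    fn.to_json(fn.struct(fn.col("value"))).alias("value"),
)

query = (
    out_df.writeStream.queryName("tws_init_state_sketch")
    .option("checkpointLocation", "/tmp/tws-init-state-checkpoint")  # placeholder path
    .outputMode("update")
    .format("console")
    .start()
)
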
@@ -760,6 +760,7 @@ def _test_transform_with_state_init_state(
time_mode="None",
checkpoint_path=None,
initial_state=None,
with_extra_transformation=False,
):
input_path = tempfile.mkdtemp()
if checkpoint_path is None:
@@ -798,6 +799,14 @@ def _test_transform_with_state_init_state(
initialState=initial_state,
)

if with_extra_transformation:
from pyspark.sql import functions as fn

tws_df = tws_df.select(
fn.col("id").cast("string").alias("key"),
fn.to_json(fn.struct(fn.col("value"))).alias("value"),
)

q = (
tws_df.writeStream.queryName("this_query")
.option("checkpointLocation", checkpoint_path)
@@ -835,6 +844,31 @@ def check_results(batch_df, batch_id):
SimpleStatefulProcessorWithInitialStateFactory(), check_results
)

def test_transform_with_state_init_state_with_extra_transformation(self):
def check_results(batch_df, batch_id):
if batch_id == 0:
# for key 0, initial state was processed and it was only processed once;
# for key 1, it did not appear in the initial state df;
# for key 3, it did not appear in the first batch of input keys
# so it won't be emitted
assert set(batch_df.sort("key").collect()) == {
Row(key="0", value=f'{{"value":"{789 + 123 + 46}"}}'),
Row(key="1", value=f'{{"value":"{146 + 346}"}}'),
}
else:
# for key 0, verify initial state was only processed once in the first batch;
# for key 3, verify init state was processed and reflected in the accumulated value
assert set(batch_df.sort("key").collect()) == {
Row(key="0", value=f'{{"value":"{789 + 123 + 46 + 67}"}}'),
Row(key="3", value=f'{{"value":"{987 + 12}"}}'),
}

self._test_transform_with_state_init_state(
SimpleStatefulProcessorWithInitialStateFactory(),
check_results,
with_extra_transformation=True,
)

def _test_transform_with_state_non_contiguous_grouping_cols(
self, stateful_processor_factory, check_results, initial_state=None
):
@@ -204,7 +204,7 @@ case class TransformWithStateInPySpark(
override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)

override lazy val references: AttributeSet =
AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes
AttributeSet(leftAttributes ++ rightReferences ++ functionExpr.references) -- producedAttributes

override protected def withNewChildrenInternal(
newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithStateInPySpark =
@@ -225,6 +225,18 @@
left.output.take(groupingAttributesLen)
}
}

// Include the initial state columns in the references so that they are not removed by column pruning.
private def rightReferences: Seq[Attribute] = {
assert(resolved, "This method is expected to be called after resolution.")
if (hasInitialState) {
right.output
} else {
// Dummy variables for passing the distribution & ordering check
// in physical operators.
left.output.take(groupingAttributesLen)
}
}
}

object TransformWithStateInPySpark {
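
With the widened references above, the columns of the initial state child should no longer be pruned away when an extra projection sits on top of the TransformWithStateInPySpark output. As a rough debugging aid (not part of this PR), the running query from the sketch near the top can be inspected from Python:

query.processAllAvailable()   # let at least one micro-batch complete
query.explain(extended=True)  # print the plans of the query's most recent execution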