[SPARK-52195][PYTHON][SS] Fix initial state column dropping issue for Python TWS #50926

Open · wants to merge 4 commits into master
Conversation

bogao007
Contributor

@bogao007 bogao007 commented May 16, 2025

What changes were proposed in this pull request?

Fix the initial state column dropping issue for Python TWS. The issue may occur when the user adds extra transformations after the TransformWithStateInPandas operator; in that case the initial state columns get pruned during optimization.

Why are the changes needed?

Without this fix, users cannot use an initial state with TransformWithStateInPandas if they apply extra transformations afterwards.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Added a unit test case.

Was this patch authored or co-authored using generative AI tooling?

No.

@@ -811,7 +811,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       isStreaming = true,
       hasInitialState,
       planLater(initialState),
-      t.rightAttributes,
+      t.rightAttributes(),
Contributor

Why do we need this?

Contributor Author

Because we changed rightAttributes to take a parameter. Even though the parameter has a default value, the compiler still requires the parentheses at the call site.
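Editorial aside: the compiler behavior mentioned here is specific to Scala, where a method declared with a parameter list must be applied with parentheses even when every parameter is defaulted. A loose Python analogy (hypothetical names, not code from this PR) shows why the bare name and the parenthesized call are different things:

```python
class Plan:
    """Toy stand-in for a logical plan node (hypothetical, not Spark's)."""

    def right_attributes(self, include_initial_state_columns=False):
        # The parameter has a default, but invoking the method still
        # requires parentheses; the bare name is just the bound method.
        cols = ["key", "value"]
        if include_initial_state_columns:
            cols.append("init_state")
        return cols

p = Plan()
print(p.right_attributes())          # ['key', 'value'] -- default used
print(p.right_attributes(True))      # ['key', 'value', 'init_state']
print(callable(p.right_attributes))  # True -- bare name is a method object
```

In Scala the analogous mistake is a compile error rather than a silently unused method value, which is why the call sites had to change to `rightAttributes()`.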

Contributor

@jingz-db jingz-db May 16, 2025

Why do we pass includesInitialStateColumns as false here, while we pass it as true inside references?

Contributor Author

@bogao007 bogao007 May 16, 2025

Here we need to pass initialStateGroupingAttrs as the input to TransformWithStateInPySparkExec, which should not include the other initial state columns. We only need to add those columns to references.

assert(resolved, "This method is expected to be called after resolution.")
if (hasInitialState) {
right.output.take(initGroupingAttrsLen)
if (includesInitialStateColumns) {
// Include the initial state columns in the references to avoid being column pruned.
Contributor

@jingz-db jingz-db May 16, 2025

If I understand your PR description correctly, the column pruning happens inside the optimizer? Do you have a code pointer to where in the optimizer the columns get pruned?

Contributor Author

Yeah, it happens when Spark applies the ColumnPruning rule. Since we didn't add these columns to references, the ColumnPruning rule assumes they can be dropped.
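To illustrate the mechanism being discussed, here is a minimal self-contained sketch (hypothetical toy classes, not Spark's actual ColumnPruning implementation) of how a pruning rule drops any child output column the operator does not declare in its references:

```python
from dataclasses import dataclass

@dataclass
class Operator:
    """Toy logical operator: its child's output and its declared references."""
    child_output: list
    references: set

def prune_columns(op: Operator) -> list:
    # A column-pruning rule keeps only the child columns the operator
    # declares in `references`; anything missing is treated as droppable.
    return [c for c in op.child_output if c in op.references]

# Before the fix: initial state columns absent from references get pruned.
before = Operator(child_output=["key", "value", "init_state"],
                  references={"key", "value"})

# After the fix: declaring them in references keeps them alive.
after = Operator(child_output=["key", "value", "init_state"],
                 references={"key", "value", "init_state"})

print(prune_columns(before))  # ['key', 'value']
print(prune_columns(after))   # ['key', 'value', 'init_state']
```

This is the essence of the bug: the operator genuinely needed the initial state columns, but because they were not listed in references, the optimizer considered them safe to drop.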

Contributor

@jingz-db jingz-db left a comment

Approved, and left some curious (but non-blocking) questions. Thanks for making the change! This was a difficult one to debug; thanks for your efforts!

@HyukjinKwon
Member

cc @HeartSaVioR

Contributor

@HeartSaVioR HeartSaVioR left a comment

Thanks for the fix. Just one suggestion. I'm not enforcing this; I just feel it'd be clearer. I'm OK if folks think this doesn't need further change.

@@ -215,10 +216,15 @@ case class TransformWithStateInPySpark(
       left.output.take(groupingAttributesLen)
     }

-  def rightAttributes: Seq[Attribute] = {
+  def rightAttributes(includesInitialStateColumns: Boolean = false): Seq[Attribute] = {
Contributor

Let's not make a single method serve two different purposes. Shall we add a rightReferences method to cover the new case?

Contributor Author

Sure, updated.
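The refactor agreed on here, splitting the flag-parameterized method into two purpose-named methods, can be sketched as follows (hypothetical Python names mirroring the Scala discussion, not the actual Spark code):

```python
class TransformWithStateNode:
    """Toy stand-in for the TransformWithStateInPySpark logical node."""

    def __init__(self, right_output, init_grouping_len):
        self._right_output = right_output
        self._init_grouping_len = init_grouping_len

    def right_attributes(self):
        # Input attributes for the physical exec: grouping attributes only.
        return self._right_output[:self._init_grouping_len]

    def right_references(self):
        # Everything the operator must declare as referenced, including the
        # initial state columns, so column pruning does not drop them.
        return list(self._right_output)

node = TransformWithStateNode(["group_key", "init_a", "init_b"],
                              init_grouping_len=1)
print(node.right_attributes())   # ['group_key']
print(node.right_references())   # ['group_key', 'init_a', 'init_b']
```

Two single-purpose methods make each call site self-documenting, which is the reviewer's point: a boolean flag forces every reader to look up what the flag changes.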

@HeartSaVioR
Contributor

@bogao007 Sorry, but could you please re-trigger the CI via an empty commit, or re-run just this module https://github.com/bogao007/spark/actions/runs/15123895213/job/42527399451 in the GitHub UI? I'd like to make sure no relevant modules are failing. Thanks!

5 participants