Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ message Expression {
UnresolvedStar unresolved_star = 5;
Alias alias = 6;
Cast cast = 7;
CaseWhen case_when = 8;
}

message Cast {
Expand Down Expand Up @@ -180,4 +181,18 @@ message Expression {
// (Optional) Alias metadata expressed as a JSON map.
optional string metadata = 3;
}

// Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
message CaseWhen {
// (Required) The seq of (branch condition, branch value)
repeated Branch branches = 1;

// (Optional) Value for the else branch.
Expression else_value = 2;

message Branch {
Expression condition = 1;
Expression value = 2;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ private[spark] object Connect {

val CONNECT_GRPC_ARROW_MAX_BATCH_SIZE =
ConfigBuilder("spark.connect.grpc.arrow.maxBatchSize")
.doc("When using Apache Arrow, limit the maximum size of one arrow batch that " +
"can be sent from server side to client side. Currently, we conservatively use 70% " +
"of it because the size is not accurate but estimated.")
.doc(
"When using Apache Arrow, limit the maximum size of one arrow batch that " +
"can be sent from server side to client side. Currently, we conservatively use 70% " +
"of it because the size is not accurate but estimated.")
.version("3.4.0")
.bytesConf(ByteUnit.MiB)
.createWithDefaultString("4m")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Expression.ExprTypeCase.UNRESOLVED_STAR =>
transformUnresolvedStar(exp.getUnresolvedStar)
case proto.Expression.ExprTypeCase.CAST => transformCast(exp.getCast)
case proto.Expression.ExprTypeCase.CASE_WHEN => transformCaseWhen(exp.getCaseWhen)
case _ =>
throw InvalidPlanInput(
s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
Expand Down Expand Up @@ -530,6 +531,14 @@ class SparkConnectPlanner(session: SparkSession) {
}
}

private def transformCaseWhen(casewhen: proto.Expression.CaseWhen): Expression = {
CaseWhen(
branches = casewhen.getBranchesList.asScala.toSeq
.map(b => (transformExpression(b.getCondition), transformExpression(b.getValue))),
elseValue =
if (casewhen.hasElseValue) Some(transformExpression(casewhen.getElseValue)) else None)
}

private def transformSetOperation(u: proto.SetOperation): LogicalPlan = {
assert(u.hasLeftInput && u.hasRightInput, "Union must have 2 inputs")

Expand Down
158 changes: 157 additions & 1 deletion python/pyspark/sql/connect/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,18 @@
# limitations under the License.
#

from typing import get_args, TYPE_CHECKING, Callable, Any, Union, overload, cast, Sequence
from typing import (
get_args,
TYPE_CHECKING,
Callable,
Any,
Union,
overload,
cast,
Sequence,
Tuple,
Optional,
)

import json
import decimal
Expand Down Expand Up @@ -130,6 +141,44 @@ def name(self) -> str:
...


class CaseWhen(Expression):
def __init__(
self, branches: Sequence[Tuple[Expression, Expression]], else_value: Optional[Expression]
):

assert isinstance(branches, list)
for branch in branches:
assert (
isinstance(branch, tuple)
and len(branch) == 2
and all(isinstance(expr, Expression) for expr in branch)
)
self._branches = branches

if else_value is not None:
assert isinstance(else_value, Expression)

self._else_value = else_value

def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = proto.Expression()
for condition, value in self._branches:
branch = proto.Expression.CaseWhen.Branch()
branch.condition.CopyFrom(condition.to_plan(session))
branch.value.CopyFrom(value.to_plan(session))
expr.case_when.branches.append(branch)

if self._else_value is not None:
expr.case_when.else_value.CopyFrom(self._else_value.to_plan(session))

return expr

def __repr__(self) -> str:
_cases = "".join([f" WHEN {c} THEN {v}" for c, v in self._branches])
_else = f" ELSE {self._else_value}" if self._else_value is not None else ""
return "CASE" + _cases + _else + " END"


class ColumnAlias(Expression):
def __init__(self, parent: Expression, alias: list[str], metadata: Any):

Expand Down Expand Up @@ -591,6 +640,113 @@ def contains(self, other: Union[PrimitiveType, "Column"]) -> "Column":
startswith = _bin_op("startsWith", _startswith_doc)
endswith = _bin_op("endsWith", _endswith_doc)

def when(self, condition: "Column", value: Any) -> "Column":
"""
Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

.. versionadded:: 3.4.0

Parameters
----------
condition : :class:`Column`
a boolean :class:`Column` expression.
value
a literal value, or a :class:`Column` expression.

Returns
-------
:class:`Column`
Column representing whether each element of Column is in conditions.

Examples
--------
>>> from pyspark.sql import functions as F
>>> df = spark.createDataFrame(
... [(2, "Alice"), (5, "Bob")], ["age", "name"])
>>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
+-----+------------------------------------------------------------+
| name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+-----+------------------------------------------------------------+
|Alice| -1|
| Bob| 1|
+-----+------------------------------------------------------------+

See Also
--------
pyspark.sql.functions.when
"""
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")

if not isinstance(self._expr, CaseWhen):
raise TypeError(
"when() can only be applied on a Column previously generated by when() function"
)

if self._expr._else_value is not None:
raise TypeError("when() cannot be applied once otherwise() is applied")

if isinstance(value, Column):
_value = value._expr
else:
_value = LiteralExpression(value)

_branches = self._expr._branches + [(condition._expr, _value)]

return Column(CaseWhen(branches=_branches, else_value=None))

def otherwise(self, value: Any) -> "Column":
"""
Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

.. versionadded:: 3.4.0

Parameters
----------
value
a literal value, or a :class:`Column` expression.

Returns
-------
:class:`Column`
Column representing whether each element of Column is unmatched conditions.

Examples
--------
>>> from pyspark.sql import functions as F
>>> df = spark.createDataFrame(
... [(2, "Alice"), (5, "Bob")], ["age", "name"])
>>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
+-----+-------------------------------------+
| name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
+-----+-------------------------------------+
|Alice| 0|
| Bob| 1|
+-----+-------------------------------------+

See Also
--------
pyspark.sql.functions.when
"""
if not isinstance(self._expr, CaseWhen):
raise TypeError(
"otherwise() can only be applied on a Column previously generated by when()"
)

if self._expr._else_value is not None:
raise TypeError(
"otherwise() can only be applied once on a Column previously generated by when()"
)

if isinstance(value, Column):
_value = value._expr
else:
_value = LiteralExpression(value)

return Column(CaseWhen(branches=self._expr._branches, else_value=_value))

def like(self: "Column", other: str) -> "Column":
"""
SQL like expression. Returns a boolean :class:`Column` based on a SQL LIKE match.
Expand Down
95 changes: 48 additions & 47 deletions python/pyspark/sql/connect/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#
from pyspark.sql.connect.column import (
Column,
CaseWhen,
Expression,
LiteralExpression,
ColumnReference,
Expand Down Expand Up @@ -546,53 +547,53 @@ def spark_partition_id() -> Column:
return _invoke_function("spark_partition_id")


# TODO(SPARK-41319): Support case-when in Column
# def when(condition: Column, value: Any) -> Column:
# """Evaluates a list of conditions and returns one of multiple possible result expressions.
# If :func:`pyspark.sql.Column.otherwise` is not invoked, None is returned for unmatched
# conditions.
#
# .. versionadded:: 3.4.0
#
# Parameters
# ----------
# condition : :class:`~pyspark.sql.Column`
# a boolean :class:`~pyspark.sql.Column` expression.
# value :
# a literal value, or a :class:`~pyspark.sql.Column` expression.
#
# Returns
# -------
# :class:`~pyspark.sql.Column`
# column representing when expression.
#
# Examples
# --------
# >>> df = spark.range(3)
# >>> df.select(when(df['id'] == 2, 3).otherwise(4).alias("age")).show()
# +---+
# |age|
# +---+
# | 4|
# | 4|
# | 3|
# +---+
#
# >>> df.select(when(df.id == 2, df.id + 1).alias("age")).show()
# +----+
# | age|
# +----+
# |null|
# |null|
# | 3|
# +----+
# """
# # Explicitly not using ColumnOrName type here to make reading condition less opaque
# if not isinstance(condition, Column):
# raise TypeError("condition should be a Column")
# v = value._jc if isinstance(value, Column) else value
#
# return _invoke_function("when", condition._jc, v)
def when(condition: Column, value: Any) -> Column:
"""Evaluates a list of conditions and returns one of multiple possible result expressions.
If :func:`pyspark.sql.Column.otherwise` is not invoked, None is returned for unmatched
conditions.

.. versionadded:: 3.4.0

Parameters
----------
condition : :class:`~pyspark.sql.Column`
a boolean :class:`~pyspark.sql.Column` expression.
value :
a literal value, or a :class:`~pyspark.sql.Column` expression.

Returns
-------
:class:`~pyspark.sql.Column`
column representing when expression.

Examples
--------
>>> df = spark.range(3)
>>> df.select(when(df['id'] == 2, 3).otherwise(4).alias("age")).show()
+---+
|age|
+---+
| 4|
| 4|
| 3|
+---+

>>> df.select(when(df.id == 2, df.id + 1).alias("age")).show()
+----+
| age|
+----+
|null|
|null|
| 3|
+----+
"""
# Explicitly not using ColumnOrName type here to make reading condition less opaque
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")

value_expr = value._expr if isinstance(value, Column) else LiteralExpression(value)

return Column(CaseWhen(branches=[(condition._expr, value_expr)], else_value=None))


# Sort Functions
Expand Down
Loading