Skip to content

Commit 2820397

Browse files
committed
[SPARK-41319][CONNECT][PYTHON] Implement Column.{when, otherwise} and Function when with UnresolvedFunction
### What changes were proposed in this pull request? Implement `Column.{when, otherwise}` and the function `when`. ### Why are the changes needed? For API coverage. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Added unit tests. Closes #38956 from zhengruifeng/connect_column_case_when_with_function. Authored-by: Ruifeng Zheng <ruifengz@apache.org> Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 2de0d45 commit 2820397

File tree

6 files changed

+326
-60
lines changed

6 files changed

+326
-60
lines changed

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,13 @@ class SparkConnectPlanner(session: SparkSession) {
551551
.Product(transformExpression(fun.getArgumentsList.asScala.head))
552552
.toAggregateExpression())
553553

554+
case "when" =>
555+
if (fun.getArgumentsCount == 0) {
556+
throw InvalidPlanInput("CaseWhen requires at least one child expression")
557+
}
558+
val children = fun.getArgumentsList.asScala.toSeq.map(transformExpression)
559+
Some(CaseWhen.createFromParser(children))
560+
554561
case "in" =>
555562
if (fun.getArgumentsCount == 0) {
556563
throw InvalidPlanInput("In requires at least one child expression")

python/pyspark/sql/connect/column.py

Lines changed: 157 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,18 @@
1515
# limitations under the License.
1616
#
1717

18-
from typing import get_args, TYPE_CHECKING, Callable, Any, Union, overload, cast, Sequence
18+
from typing import (
19+
get_args,
20+
TYPE_CHECKING,
21+
Callable,
22+
Any,
23+
Union,
24+
overload,
25+
cast,
26+
Sequence,
27+
Tuple,
28+
Optional,
29+
)
1930

2031
import json
2132
import decimal
@@ -152,6 +163,44 @@ def name(self) -> str:
152163
...
153164

154165

166+
class CaseWhen(Expression):
    """Client-side expression for ``CASE WHEN ... THEN ... [ELSE ...] END``.

    Stores an ordered list of ``(condition, value)`` branches and an optional
    else-value. When converted to a plan it is encoded as an unresolved
    function named ``when`` whose arguments are the flattened branches,
    followed by the else-value if one is set.
    """

    def __init__(
        self, branches: Sequence[Tuple[Expression, Expression]], else_value: Optional[Expression]
    ):
        # Invariant checks: `branches` must be a list of 2-tuples of Expression,
        # and `else_value` must be an Expression when it is given.
        assert isinstance(branches, list)
        for pair in branches:
            assert isinstance(pair, tuple)
            assert len(pair) == 2
            assert all(isinstance(item, Expression) for item in pair)
        assert else_value is None or isinstance(else_value, Expression)

        self._branches = branches
        self._else_value = else_value

    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
        """Build the plan by delegating to an ``UnresolvedFunction`` named ``when``."""
        # Flatten to [cond1, value1, cond2, value2, ...]; the tuple unpacking is
        # kept so a malformed branch still fails at the same point as before.
        flattened = [expr for cond, val in self._branches for expr in (cond, val)]

        # A trailing odd argument carries the ELSE value.
        if self._else_value is not None:
            flattened.append(self._else_value)

        return UnresolvedFunction(name="when", args=flattened).to_plan(session)

    def __repr__(self) -> str:
        """Render the expression in SQL-like ``CASE ... END`` form."""
        pieces = ["CASE"]
        pieces.extend(f" WHEN {c} THEN {v}" for c, v in self._branches)
        if self._else_value is not None:
            pieces.append(f" ELSE {self._else_value}")
        pieces.append(" END")
        return "".join(pieces)
202+
203+
155204
class ColumnAlias(Expression):
156205
def __init__(self, parent: Expression, alias: list[str], metadata: Any):
157206

@@ -706,6 +755,113 @@ def contains(self, other: Union["PrimitiveType", "Column"]) -> "Column":
706755
startswith = _bin_op("startsWith", _startswith_doc)
707756
endswith = _bin_op("endsWith", _endswith_doc)
708757

758+
def when(self, condition: "Column", value: Any) -> "Column":
    """
    Add another ``WHEN condition THEN value`` branch to a case-when column.
    If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

    .. versionadded:: 3.4.0

    Parameters
    ----------
    condition : :class:`Column`
        a boolean :class:`Column` expression.
    value
        a literal value, or a :class:`Column` expression.

    Returns
    -------
    :class:`Column`
        A new case-when Column holding the existing branches plus the new one.

    Raises
    ------
    TypeError
        If ``condition`` is not a :class:`Column`, if this Column was not
        produced by ``when()``, or if ``otherwise()`` was already applied.

    Examples
    --------
    >>> from pyspark.sql import functions as F
    >>> df = spark.createDataFrame(
    ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
    >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
    +-----+------------------------------------------------------------+
    | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
    +-----+------------------------------------------------------------+
    |Alice|                                                          -1|
    |  Bob|                                                           1|
    +-----+------------------------------------------------------------+

    See Also
    --------
    pyspark.sql.functions.when
    """
    if not isinstance(condition, Column):
        raise TypeError("condition should be a Column")

    case_expr = self._expr
    if not isinstance(case_expr, CaseWhen):
        raise TypeError(
            "when() can only be applied on a Column previously generated by when() function"
        )

    # Once an ELSE value exists the expression is closed to further branches.
    if case_expr._else_value is not None:
        raise TypeError("when() cannot be applied once otherwise() is applied")

    # Plain Python values are wrapped as literals; Columns contribute their expression.
    then_expr = (
        value._expr
        if isinstance(value, Column)
        else LiteralExpression(value, LiteralExpression._infer_type(value))
    )

    # List concatenation builds a new list, so the existing branches are not mutated.
    extended = case_expr._branches + [(condition._expr, then_expr)]
    return Column(CaseWhen(branches=extended, else_value=None))
813+
814+
def otherwise(self, value: Any) -> "Column":
    """
    Set the ``ELSE`` value of a case-when column.
    If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

    .. versionadded:: 3.4.0

    Parameters
    ----------
    value
        a literal value, or a :class:`Column` expression.

    Returns
    -------
    :class:`Column`
        A new case-when Column with the same branches and ``value`` as its else-value.

    Raises
    ------
    TypeError
        If this Column was not produced by ``when()``, or if ``otherwise()``
        was already applied to it.

    Examples
    --------
    >>> from pyspark.sql import functions as F
    >>> df = spark.createDataFrame(
    ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
    >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
    +-----+-------------------------------------+
    | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
    +-----+-------------------------------------+
    |Alice|                                    0|
    |  Bob|                                    1|
    +-----+-------------------------------------+

    See Also
    --------
    pyspark.sql.functions.when
    """
    case_expr = self._expr
    if not isinstance(case_expr, CaseWhen):
        raise TypeError(
            "otherwise() can only be applied on a Column previously generated by when()"
        )

    # An else-value may be set only once.
    if case_expr._else_value is not None:
        raise TypeError(
            "otherwise() can only be applied once on a Column previously generated by when()"
        )

    # Plain Python values are wrapped as literals; Columns contribute their expression.
    else_expr = (
        value._expr
        if isinstance(value, Column)
        else LiteralExpression(value, LiteralExpression._infer_type(value))
    )

    return Column(CaseWhen(branches=case_expr._branches, else_value=else_expr))
864+
709865
def like(self: "Column", other: str) -> "Column":
710866
"""
711867
SQL like expression. Returns a boolean :class:`Column` based on a SQL LIKE match.
@@ -902,9 +1058,6 @@ def cast(self, dataType: Union[DataType, str]) -> "Column":
9021058
def __repr__(self) -> str:
9031059
return "Column<'%s'>" % self._expr.__repr__()
9041060

905-
def otherwise(self, *args: Any, **kwargs: Any) -> None:
906-
raise NotImplementedError("otherwise() is not yet implemented.")
907-
9081061
def over(self, *args: Any, **kwargs: Any) -> None:
9091062
raise NotImplementedError("over() is not yet implemented.")
9101063

@@ -943,9 +1096,6 @@ def isin(self, *cols: Any) -> "Column":
9431096

9441097
return Column(UnresolvedFunction("in", [self._expr] + [lit(c)._expr for c in _cols]))
9451098

946-
def when(self, *args: Any, **kwargs: Any) -> None:
947-
raise NotImplementedError("when() is not yet implemented.")
948-
9491099
def getItem(self, *args: Any, **kwargs: Any) -> None:
9501100
raise NotImplementedError("getItem() is not yet implemented.")
9511101

python/pyspark/sql/connect/functions.py

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717
from pyspark.sql.connect.column import (
1818
Column,
19+
CaseWhen,
1920
Expression,
2021
LiteralExpression,
2122
ColumnReference,
@@ -549,53 +550,53 @@ def spark_partition_id() -> Column:
549550
return _invoke_function("spark_partition_id")
550551

551552

552-
# TODO(SPARK-41319): Support case-when in Column
553-
# def when(condition: Column, value: Any) -> Column:
554-
# """Evaluates a list of conditions and returns one of multiple possible result expressions.
555-
# If :func:`pyspark.sql.Column.otherwise` is not invoked, None is returned for unmatched
556-
# conditions.
557-
#
558-
# .. versionadded:: 3.4.0
559-
#
560-
# Parameters
561-
# ----------
562-
# condition : :class:`~pyspark.sql.Column`
563-
# a boolean :class:`~pyspark.sql.Column` expression.
564-
# value :
565-
# a literal value, or a :class:`~pyspark.sql.Column` expression.
566-
#
567-
# Returns
568-
# -------
569-
# :class:`~pyspark.sql.Column`
570-
# column representing when expression.
571-
#
572-
# Examples
573-
# --------
574-
# >>> df = spark.range(3)
575-
# >>> df.select(when(df['id'] == 2, 3).otherwise(4).alias("age")).show()
576-
# +---+
577-
# |age|
578-
# +---+
579-
# | 4|
580-
# | 4|
581-
# | 3|
582-
# +---+
583-
#
584-
# >>> df.select(when(df.id == 2, df.id + 1).alias("age")).show()
585-
# +----+
586-
# | age|
587-
# +----+
588-
# |null|
589-
# |null|
590-
# | 3|
591-
# +----+
592-
# """
593-
# # Explicitly not using ColumnOrName type here to make reading condition less opaque
594-
# if not isinstance(condition, Column):
595-
# raise TypeError("condition should be a Column")
596-
# v = value._jc if isinstance(value, Column) else value
597-
#
598-
# return _invoke_function("when", condition._jc, v)
553+
def when(condition: Column, value: Any) -> Column:
    """Start a case-when expression with a single ``WHEN condition THEN value`` branch.
    If :func:`pyspark.sql.Column.otherwise` is not invoked, None is returned for unmatched
    conditions.

    .. versionadded:: 3.4.0

    Parameters
    ----------
    condition : :class:`~pyspark.sql.Column`
        a boolean :class:`~pyspark.sql.Column` expression.
    value :
        a literal value, or a :class:`~pyspark.sql.Column` expression.

    Returns
    -------
    :class:`~pyspark.sql.Column`
        column representing when expression.

    Raises
    ------
    TypeError
        If ``condition`` is not a :class:`~pyspark.sql.Column`.

    Examples
    --------
    >>> df = spark.range(3)
    >>> df.select(when(df['id'] == 2, 3).otherwise(4).alias("age")).show()
    +---+
    |age|
    +---+
    |  4|
    |  4|
    |  3|
    +---+

    >>> df.select(when(df.id == 2, df.id + 1).alias("age")).show()
    +----+
    | age|
    +----+
    |null|
    |null|
    |   3|
    +----+
    """
    # Explicitly not using ColumnOrName type here to make reading condition less opaque
    if not isinstance(condition, Column):
        raise TypeError("condition should be a Column")

    # Non-Column values are turned into literal columns first.
    if isinstance(value, Column):
        then_col = value
    else:
        then_col = lit(value)

    first_branch = (condition._expr, then_col._expr)
    return Column(CaseWhen(branches=[first_branch], else_value=None))
599600

600601

601602
# Sort Functions

python/pyspark/sql/tests/connect/test_connect_column.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,7 @@ def test_unsupported_functions(self):
347347
# SPARK-41225: Disable unsupported functions.
348348
c = self.connect.range(1).id
349349
for f in (
350-
"otherwise",
351350
"over",
352-
"when",
353351
"getItem",
354352
"astype",
355353
"between",

0 commit comments

Comments
 (0)