Commit a9da924

[SPARK-40538][CONNECT] Improve built-in function support for Python client

### What changes were proposed in this pull request?
This patch changes the way simple scalar built-in functions are resolved in the Python Spark Connect client. Previously, the client and the server planner special-cased a handful of functions by name (e.g. `gt`, `eq`). With the changes in this patch, trivial binary operators like `<`, `+`, and so on are mapped to their name equivalents in Spark so that the dynamic function lookup works; the sketch below illustrates the idea.
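
A minimal, self-contained sketch of the new scheme (`Expr` and its members are illustrative stand-ins for the client's `Expression`/`ColumnRef` classes, not the actual `pyspark.sql.connect` API):

```python
from typing import Any, Callable, Optional


def _bin_op(name: str) -> Callable[["Expr", Any], "Expr"]:
    # Each operator dunder simply records the Spark function name; the server
    # resolves that name dynamically instead of special-casing "gt", "eq", ...
    def _(self: "Expr", other: Any) -> "Expr":
        return Expr(name, [self, other])

    return _


class Expr:
    """Illustrative stand-in for the client's expression classes."""

    def __init__(self, name: str, args: Optional[list] = None) -> None:
        self.name, self.args = name, args or []

    def __repr__(self) -> str:
        return f"UnresolvedFunction(parts=[{self.name!r}], args={self.args})"

    __add__ = _bin_op("+")
    __lt__ = _bin_op("<")


# Operators now serialize under their Spark SQL names:
print(Expr("a") < Expr("b"))  # UnresolvedFunction(parts=['<'], args=[...])
```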

In addition, it cleans up the Scala planner side by removing the now-unnecessary code that translated these trivial binary expressions into their Catalyst equivalents.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit tests and end-to-end tests.

Closes #38270 from grundprinzip/spark-40538.

Authored-by: Martin Grund <martin.grund@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
grundprinzip authored and HyukjinKwon committed Oct 18, 2022
1 parent fc4643b commit a9da924
Showing 8 changed files with 156 additions and 60 deletions.
@@ -92,6 +92,44 @@ package object dsl {
         .build()
     }
 
+    /**
+     * Creates an unresolved function from name parts.
+     *
+     * @param nameParts the parts of the function identifier, e.g. Seq("default", "hex")
+     * @param args the arguments to pass to the function
+     * @return
+     *   Expression wrapping the unresolved function.
+     */
+    def callFunction(nameParts: Seq[String], args: Seq[proto.Expression]): proto.Expression = {
+      proto.Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          proto.Expression.UnresolvedFunction
+            .newBuilder()
+            .addAllParts(nameParts.asJava)
+            .addAllArguments(args.asJava))
+        .build()
+    }
+
+    /**
+     * Creates an UnresolvedFunction from a single identifier.
+     *
+     * @param name the function name as a single, unqualified identifier
+     * @param args the arguments to pass to the function
+     * @return
+     *   Expression wrapping the unresolved function.
+     */
+    def callFunction(name: String, args: Seq[proto.Expression]): proto.Expression = {
+      proto.Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          proto.Expression.UnresolvedFunction
+            .newBuilder()
+            .addParts(name)
+            .addAllArguments(args.asJava))
+        .build()
+    }
+
     implicit def intToLiteral(i: Int): proto.Expression =
       proto.Expression
         .newBuilder()
@@ -197,10 +197,6 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       limitExpr = expressions.Literal(limit.getLimit, IntegerType))
   }
 
-  private def lookupFunction(name: String, args: Seq[Expression]): Expression = {
-    UnresolvedFunction(Seq(name), args, isDistinct = false)
-  }
-
   /**
    * Translates a scalar function from proto to the Catalyst expression.
    *
@@ -211,21 +207,14 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
    * @return
    */
   private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = {
-    val funName = fun.getPartsList.asScala.mkString(".")
-    funName match {
-      case "gt" =>
-        assert(fun.getArgumentsCount == 2, "`gt` function must have two arguments.")
-        expressions.GreaterThan(
-          transformExpression(fun.getArguments(0)),
-          transformExpression(fun.getArguments(1)))
-      case "eq" =>
-        assert(fun.getArgumentsCount == 2, "`eq` function must have two arguments.")
-        expressions.EqualTo(
-          transformExpression(fun.getArguments(0)),
-          transformExpression(fun.getArguments(1)))
-      case _ =>
-        lookupFunction(funName, fun.getArgumentsList.asScala.map(transformExpression).toSeq)
+    if (fun.getPartsCount == 1 && fun.getParts(0).contains(".")) {
+      throw new IllegalArgumentException(
+        "Function identifier must be passed as sequence of name parts.")
     }
+    UnresolvedFunction(
+      fun.getPartsList.asScala.toSeq,
+      fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+      isDistinct = false)
   }
 
   private def transformAlias(alias: proto.Expression.Alias): Expression = {
@@ -197,7 +197,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {

     val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction(
       proto.Expression.UnresolvedFunction.newBuilder
-        .addAllParts(Seq("eq").asJava)
+        .addAllParts(Seq("==").asJava)
         .addArguments(unresolvedAttribute)
         .addArguments(unresolvedAttribute)
         .build())
@@ -50,6 +50,34 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
   }
 
+  test("UnresolvedFunction resolution.") {
+    {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      assertThrows[IllegalArgumentException] {
+        transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr))))
+      }
+    }
+
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(
+        connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr))))
+    }
+
+    assertThrows[UnsupportedOperationException] {
+      connectPlan.analyze
+    }
+
+    val validPlan = {
+      import org.apache.spark.sql.connect.dsl.expressions._
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.select(callFunction(Seq("hex"), Seq("id".protoAttr))))
+    }
+    assert(validPlan.analyze != null)
+  }
+
   test("Basic filter") {
     val connectPlan = {
       import org.apache.spark.sql.connect.dsl.expressions._
80 changes: 40 additions & 40 deletions python/pyspark/sql/connect/column.py
@@ -26,11 +26,51 @@
 import pyspark.sql.connect.proto as proto
 
 
+def _bin_op(
+    name: str, doc: str = "binary function", reverse: bool = False
+) -> Callable[["ColumnRef", Any], "Expression"]:
+    def _(self: "ColumnRef", other: Any) -> "Expression":
+        if isinstance(other, get_args(PrimitiveType)):
+            other = LiteralExpression(other)
+        if not reverse:
+            return ScalarFunctionExpression(name, self, other)
+        else:
+            return ScalarFunctionExpression(name, other, self)
+
+    return _
+
+
 class Expression(object):
     """
     Expression base class.
     """
 
+    __gt__ = _bin_op(">")
+    __lt__ = _bin_op("<")
+    __add__ = _bin_op("+")
+    __sub__ = _bin_op("-")
+    __mul__ = _bin_op("*")
+    __div__ = _bin_op("/")
+    __truediv__ = _bin_op("/")
+    __mod__ = _bin_op("%")
+    __radd__ = _bin_op("+", reverse=True)
+    __rsub__ = _bin_op("-", reverse=True)
+    __rmul__ = _bin_op("*", reverse=True)
+    __rdiv__ = _bin_op("/", reverse=True)
+    __rtruediv__ = _bin_op("/", reverse=True)
+    __pow__ = _bin_op("pow")
+    __rpow__ = _bin_op("pow", reverse=True)
+    __ge__ = _bin_op(">=")
+    __le__ = _bin_op("<=")
+
+    def __eq__(self, other: Any) -> "Expression":  # type: ignore[override]
+        """Returns a binary expression with the current column as the left
+        side and the other expression as the right side.
+        """
+        if isinstance(other, get_args(PrimitiveType)):
+            other = LiteralExpression(other)
+        return ScalarFunctionExpression("==", self, other)
+
     def __init__(self) -> None:
         pass
 
@@ -73,20 +113,6 @@ def __str__(self) -> str:
return f"Literal({self._value})"


def _bin_op(
name: str, doc: str = "binary function", reverse: bool = False
) -> Callable[["ColumnRef", Any], Expression]:
def _(self: "ColumnRef", other: Any) -> Expression:
if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
if not reverse:
return ScalarFunctionExpression(name, self, other)
else:
return ScalarFunctionExpression(name, other, self)

return _


class ColumnRef(Expression):
"""Represents a column reference. There is no guarantee that this column
actually exists. In the context of this project, we refer by its name and
@@ -105,32 +131,6 @@ def name(self) -> str:
"""Returns the qualified name of the column reference."""
return ".".join(self._parts)

__gt__ = _bin_op("gt")
__lt__ = _bin_op("lt")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
__div__ = _bin_op("divide")
__truediv__ = _bin_op("divide")
__mod__ = _bin_op("modulo")
__radd__ = _bin_op("plus", reverse=True)
__rsub__ = _bin_op("minus", reverse=True)
__rmul__ = _bin_op("multiply", reverse=True)
__rdiv__ = _bin_op("divide", reverse=True)
__rtruediv__ = _bin_op("divide", reverse=True)
__pow__ = _bin_op("pow")
__rpow__ = _bin_op("pow", reverse=True)
__ge__ = _bin_op("greterEquals")
__le__ = _bin_op("lessEquals")

def __eq__(self, other: Any) -> Expression: # type: ignore[override]
"""Returns a binary expression with the current column as the left
side and the other expression as the right side.
"""
if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
return ScalarFunctionExpression("eq", self, other)

def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
"""Returns the Proto representation of the expression."""
expr = proto.Expression()
12 changes: 12 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -19,9 +19,12 @@
 import unittest
 import tempfile
 
+import pandas
+
 from pyspark.sql import SparkSession, Row
 from pyspark.sql.connect.client import RemoteSparkSession
 from pyspark.sql.connect.function_builder import udf
+from pyspark.sql.connect.functions import lit
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import ReusedPySparkTestCase
 
@@ -79,6 +82,15 @@ def test_simple_explain_string(self):
         result = df.explain()
         self.assertGreater(len(result), 0)
 
+    def test_simple_binary_expressions(self):
+        """Test complex expression"""
+        df = self.connect.read.table(self.tbl_name)
+        pd = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas()
+        self.assertEqual(len(pd.index), 4)
+
+        res = pandas.DataFrame(data={"id": [0, 30, 60, 90]})
+        self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}")
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_basic import *  # noqa: F401
@@ -16,6 +16,7 @@
 #
 
 from pyspark.testing.connectutils import PlanOnlyTestFixture
+from pyspark.sql.connect.proto import Expression as ProtoExpression
 import pyspark.sql.connect as c
 import pyspark.sql.connect.plan as p
 import pyspark.sql.connect.column as col
@@ -51,6 +52,34 @@ def test_column_literals(self):
         plan = fun.lit(10).to_plan(None)
         self.assertIs(plan.literal.i32, 10)
 
+    def test_column_expressions(self):
+        """Test a more complex combination of expressions and their translation into
+        the protobuf structure."""
+        df = c.DataFrame.withPlan(p.Read("table"))
+
+        expr = df.id % fun.lit(10) == fun.lit(10)
+        expr_plan = expr.to_plan(None)
+        self.assertIsNotNone(expr_plan.unresolved_function)
+        self.assertEqual(expr_plan.unresolved_function.parts[0], "==")
+
+        lit_fun = expr_plan.unresolved_function.arguments[1]
+        self.assertIsInstance(lit_fun, ProtoExpression)
+        self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
+        self.assertEqual(lit_fun.literal.i32, 10)
+
+        mod_fun = expr_plan.unresolved_function.arguments[0]
+        self.assertIsInstance(mod_fun, ProtoExpression)
+        self.assertIsInstance(mod_fun.unresolved_function, ProtoExpression.UnresolvedFunction)
+        self.assertEqual(len(mod_fun.unresolved_function.arguments), 2)
+        self.assertIsInstance(mod_fun.unresolved_function.arguments[0], ProtoExpression)
+        self.assertIsInstance(
+            mod_fun.unresolved_function.arguments[0].unresolved_attribute,
+            ProtoExpression.UnresolvedAttribute,
+        )
+        self.assertEqual(
+            mod_fun.unresolved_function.arguments[0].unresolved_attribute.parts, ["id"]
+        )
+
 
 if __name__ == "__main__":
     import unittest
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -40,7 +40,7 @@ def test_filter(self):
                 plan.root.filter.condition.unresolved_function, proto.Expression.UnresolvedFunction
             )
         )
-        self.assertEqual(plan.root.filter.condition.unresolved_function.parts, ["gt"])
+        self.assertEqual(plan.root.filter.condition.unresolved_function.parts, [">"])
         self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 2)
 
     def test_relation_alias(self):
