
Commit e7d4a75

xinrong-meng authored and dongjoon-hyun committed
[SPARK-42210][CONNECT][PYTHON] Standardize registered pickled Python UDFs
### What changes were proposed in this pull request?

Standardize registered pickled Python UDFs; specifically, implement `spark.udf.register()`.

### Why are the changes needed?

To reach parity with vanilla PySpark.

### Does this PR introduce _any_ user-facing change?

Yes. `spark.udf.register()` is added as shown below:

```py
>>> spark.udf
<pyspark.sql.connect.udf.UDFRegistration object at 0x7fbca0077dc0>
>>> f = spark.udf.register("f", lambda x: x+1, "int")
>>> f
<function <lambda> at 0x7fbc905e5e50>
>>> spark.sql("SELECT f(id) FROM range(2)").collect()
[Row(f(id)=1), Row(f(id)=2)]
```

### How was this patch tested?

Unit tests.

Closes apache#39860 from xinrong-meng/connect_registered_udf.

Lead-authored-by: Xinrong Meng <xinrong@apache.org>
Co-authored-by: Xinrong Meng <xinrong.apache@gmail.com>
Signed-off-by: Xinrong Meng <xinrong@apache.org>
(cherry picked from commit e7eb836)
Signed-off-by: Xinrong Meng <xinrong@apache.org>
1 parent 449099c commit e7d4a75

File tree

12 files changed: +216, -35 lines

connector/connect/common/src/main/protobuf/spark/connect/commands.proto
Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ option java_package = "org.apache.spark.connect.proto";
 // produce a relational result.
 message Command {
   oneof command_type {
+    CommonInlineUserDefinedFunction register_function = 1;
     WriteOperation write_operation = 2;
     CreateDataFrameViewCommand create_dataframe_view = 3;
     WriteOperationV2 write_operation_v2 = 4;
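
For orientation, a minimal sketch (not part of this patch) of how the new oneof arm behaves in the regenerated Python protos; every message and field name is taken from the diffs below:

```py
# Minimal sketch against the regenerated protos from this patch.
from pyspark.sql.connect.proto import commands_pb2

cmd = commands_pb2.Command()
# Assigning into the new arm selects it within the command_type oneof.
cmd.register_function.function_name = "f"
assert cmd.WhichOneof("command_type") == "register_function"
```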

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
Lines changed: 33 additions & 0 deletions

@@ -44,6 +44,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.CreateViewCommand
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.functions.{col, expr}
 import org.apache.spark.sql.internal.CatalogImpl
 import org.apache.spark.sql.types._
@@ -1399,6 +1400,8 @@ class SparkConnectPlanner(val session: SparkSession) {

   def process(command: proto.Command): Unit = {
     command.getCommandTypeCase match {
+      case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
+        handleRegisterUserDefinedFunction(command.getRegisterFunction)
       case proto.Command.CommandTypeCase.WRITE_OPERATION =>
         handleWriteOperation(command.getWriteOperation)
       case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
@@ -1411,6 +1414,36 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }

+  private def handleRegisterUserDefinedFunction(
+      fun: proto.CommonInlineUserDefinedFunction): Unit = {
+    fun.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        handleRegisterPythonUDF(fun)
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
+    val udf = fun.getPythonUdf
+    val function = transformPythonFunction(udf)
+    val udpf = UserDefinedPythonFunction(
+      name = fun.getFunctionName,
+      func = function,
+      dataType = DataType.parseTypeWithFallback(
+        schema = udf.getOutputType,
+        parser = DataType.fromDDL,
+        fallbackParser = DataType.fromJson) match {
+        case s: DataType => s
+        case other => throw InvalidPlanInput(s"Invalid return type $other")
+      },
+      pythonEvalType = udf.getEvalType,
+      udfDeterministic = fun.getDeterministic)
+
+    session.udf.registerPython(fun.getFunctionName, udpf)
+  }
+
   private def handleCommandPlugin(extension: ProtoAny): Unit = {
     SparkConnectPluginRegistry.commandRegistry
       // Lazily traverse the collection.
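
The return type arrives as a string: `parseTypeWithFallback` tries DDL first and falls back to JSON, which is the form the Python client actually sends (`return_type.json()` in client.py below). For intuition, a small illustration of the two encodings; note `_parse_datatype_json_string` is a private PySpark helper, used here only for illustration:

```py
# Illustration only: two string encodings of the same Spark type. The client
# in this patch sends the JSON form; the server tries DDL first, then JSON.
from pyspark.sql.types import IntegerType, _parse_datatype_json_string

t = IntegerType()
print(t.simpleString())  # 'int'       -- the DDL form
print(t.json())          # '"integer"' -- the JSON form register_udf sends
assert _parse_datatype_json_string(t.json()) == t
```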

python/pyspark/sql/connect/client.py
Lines changed: 59 additions & 0 deletions

@@ -30,6 +30,7 @@
 import urllib.parse
 import uuid
 import json
+import sys
 from types import TracebackType
 from typing import (
     Iterable,
@@ -67,11 +68,18 @@
     TempTableAlreadyExistsException,
     IllegalArgumentException,
 )
+from pyspark.sql.connect.expressions import (
+    PythonUDF,
+    CommonInlineUserDefinedFunction,
+)
 from pyspark.sql.types import (
     DataType,
     StructType,
     StructField,
 )
+from pyspark.sql.utils import is_remote
+from pyspark.serializers import CloudPickleSerializer
+from pyspark.rdd import PythonEvalType


 def _configure_logging() -> logging.Logger:
@@ -428,6 +436,57 @@ def __init__(
         self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
         # Configure logging for the SparkConnect client.

+    def register_udf(
+        self,
+        function: Any,
+        return_type: Union[str, DataType],
+        name: Optional[str] = None,
+        eval_type: int = PythonEvalType.SQL_BATCHED_UDF,
+        deterministic: bool = True,
+    ) -> str:
+        """Create a temporary UDF in the session catalog on the other side. We generate a
+        temporary name for it."""
+
+        from pyspark.sql import SparkSession as PySparkSession
+
+        if name is None:
+            name = f"fun_{uuid.uuid4().hex}"
+
+        # convert str return_type to DataType
+        if isinstance(return_type, str):
+
+            assert is_remote()
+            return_type_schema = (  # a workaround to parse the DataType from DDL strings
+                PySparkSession.builder.getOrCreate()
+                .createDataFrame(data=[], schema=return_type)
+                .schema
+            )
+            assert len(return_type_schema.fields) == 1, "returnType should be singular"
+            return_type = return_type_schema.fields[0].dataType
+
+        # construct a PythonUDF
+        py_udf = PythonUDF(
+            output_type=return_type.json(),
+            eval_type=eval_type,
+            command=CloudPickleSerializer().dumps((function, return_type)),
+            python_ver="%d.%d" % sys.version_info[:2],
+        )
+
+        # construct a CommonInlineUserDefinedFunction
+        fun = CommonInlineUserDefinedFunction(
+            function_name=name,
+            deterministic=deterministic,
+            arguments=[],
+            function=py_udf,
+        ).to_command(self)
+
+        # construct the request
+        req = self._execute_plan_request_with_metadata()
+        req.plan.command.register_function.CopyFrom(fun)
+
+        self._execute(req)
+        return name
+
     def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
         return [
             PlanMetrics(
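
Assuming a reachable Spark Connect server (the URL below is a placeholder) and going through the session's internal `_client` handle, a hedged usage sketch of the new method:

```py
# A hedged sketch, not part of the patch. "sc://localhost" is a placeholder,
# and _client is the session's internal SparkConnectClient. register_udf
# pickles the function, sends a register_function command, and returns the
# name the UDF was bound to.
from pyspark.sql.connect.session import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
name = spark._client.register_udf(lambda x: x + 1, IntegerType(), name="plus_one")
assert name == "plus_one"
spark.sql("SELECT plus_one(id) FROM range(2)").show()
```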

python/pyspark/sql/connect/expressions.py
Lines changed: 7 additions & 0 deletions

@@ -542,6 +542,13 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         )
         return expr

+    def to_command(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction":
+        expr = proto.CommonInlineUserDefinedFunction()
+        expr.function_name = self._function_name
+        expr.deterministic = self._deterministic
+        expr.python_udf.CopyFrom(self._function.to_plan(session))
+        return expr
+
     def __repr__(self) -> str:
         return (
             f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])}), "

python/pyspark/sql/connect/proto/commands_pb2.py
Lines changed: 20 additions & 20 deletions

@@ -36,7 +36,7 @@


 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xcc\x02\n\x07\x43ommand\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xe6\x05\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12\x1f\n\ntable_name\x18\x04 \x01(\tH\x00R\ttableName\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_type"\x9b\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1a\n\x08provider\x18\x03 \x01(\tR\x08provider\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xe6\x05\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06source\x18\x02 \x01(\tR\x06source\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12\x1f\n\ntable_name\x18\x04 \x01(\tH\x00R\ttableName\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_type"\x9b\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1a\n\x08provider\x18\x03 \x01(\tR\x08provider\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
 )


@@ -147,23 +147,23 @@
 _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._options = None
 _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_options = b"8\001"
 _COMMAND._serialized_start = 166
-_COMMAND._serialized_end = 498
-_CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 501
-_CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 651
-_WRITEOPERATION._serialized_start = 654
-_WRITEOPERATION._serialized_end = 1396
-_WRITEOPERATION_OPTIONSENTRY._serialized_start = 1092
-_WRITEOPERATION_OPTIONSENTRY._serialized_end = 1150
-_WRITEOPERATION_BUCKETBY._serialized_start = 1152
-_WRITEOPERATION_BUCKETBY._serialized_end = 1243
-_WRITEOPERATION_SAVEMODE._serialized_start = 1246
-_WRITEOPERATION_SAVEMODE._serialized_end = 1383
-_WRITEOPERATIONV2._serialized_start = 1399
-_WRITEOPERATIONV2._serialized_end = 2194
-_WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1092
-_WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1150
-_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 1966
-_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2032
-_WRITEOPERATIONV2_MODE._serialized_start = 2035
-_WRITEOPERATIONV2_MODE._serialized_end = 2194
+_COMMAND._serialized_end = 593
+_CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
+_CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
+_WRITEOPERATION._serialized_start = 749
+_WRITEOPERATION._serialized_end = 1491
+_WRITEOPERATION_OPTIONSENTRY._serialized_start = 1187
+_WRITEOPERATION_OPTIONSENTRY._serialized_end = 1245
+_WRITEOPERATION_BUCKETBY._serialized_start = 1247
+_WRITEOPERATION_BUCKETBY._serialized_end = 1338
+_WRITEOPERATION_SAVEMODE._serialized_start = 1341
+_WRITEOPERATION_SAVEMODE._serialized_end = 1478
+_WRITEOPERATIONV2._serialized_start = 1494
+_WRITEOPERATIONV2._serialized_end = 2289
+_WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1187
+_WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1245
+_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2061
+_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2127
+_WRITEOPERATIONV2_MODE._serialized_start = 2130
+_WRITEOPERATIONV2_MODE._serialized_end = 2289
 # @@protoc_insertion_point(module_scope)
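
Note that every offset at or after `_COMMAND._serialized_end` shifts by the same 95 bytes (for example, 593 - 498 = 95 and 2289 - 2194 = 95): that is the size of the new `register_function` entry in the serialized file descriptor above.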

python/pyspark/sql/connect/proto/commands_pb2.pyi
Lines changed: 16 additions & 1 deletion

@@ -59,11 +59,16 @@ class Command(google.protobuf.message.Message):

     DESCRIPTOR: google.protobuf.descriptor.Descriptor

+    REGISTER_FUNCTION_FIELD_NUMBER: builtins.int
     WRITE_OPERATION_FIELD_NUMBER: builtins.int
     CREATE_DATAFRAME_VIEW_FIELD_NUMBER: builtins.int
     WRITE_OPERATION_V2_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
+    def register_function(
+        self,
+    ) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction: ...
+    @property
     def write_operation(self) -> global___WriteOperation: ...
     @property
     def create_dataframe_view(self) -> global___CreateDataFrameViewCommand: ...
@@ -77,6 +82,8 @@ class Command(google.protobuf.message.Message):
     def __init__(
         self,
         *,
+        register_function: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+        | None = ...,
         write_operation: global___WriteOperation | None = ...,
         create_dataframe_view: global___CreateDataFrameViewCommand | None = ...,
         write_operation_v2: global___WriteOperationV2 | None = ...,
@@ -91,6 +98,8 @@ class Command(google.protobuf.message.Message):
             b"create_dataframe_view",
             "extension",
             b"extension",
+            "register_function",
+            b"register_function",
             "write_operation",
             b"write_operation",
             "write_operation_v2",
@@ -106,6 +115,8 @@ class Command(google.protobuf.message.Message):
             b"create_dataframe_view",
             "extension",
             b"extension",
+            "register_function",
+            b"register_function",
             "write_operation",
             b"write_operation",
             "write_operation_v2",
@@ -115,7 +126,11 @@ class Command(google.protobuf.message.Message):
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["command_type", b"command_type"]
     ) -> typing_extensions.Literal[
-        "write_operation", "create_dataframe_view", "write_operation_v2", "extension"
+        "register_function",
+        "write_operation",
+        "create_dataframe_view",
+        "write_operation_v2",
+        "extension",
     ] | None: ...

 global___Command = Command

python/pyspark/sql/connect/session.py
Lines changed: 7 additions & 2 deletions

@@ -68,6 +68,7 @@
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import OptionalPrimitiveType
     from pyspark.sql.connect.catalog import Catalog
+    from pyspark.sql.connect.udf import UDFRegistration


 class SparkSession:
@@ -436,8 +437,12 @@ def readStream(self) -> Any:
         raise NotImplementedError("readStream() is not implemented.")

     @property
-    def udf(self) -> Any:
-        raise NotImplementedError("udf() is not implemented.")
+    def udf(self) -> "UDFRegistration":
+        from pyspark.sql.connect.udf import UDFRegistration
+
+        return UDFRegistration(self)
+
+    udf.__doc__ = PySparkSession.udf.__doc__

     @property
     def version(self) -> str:
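
With this change the property resolves lazily to the Connect `UDFRegistration` instead of raising. A short hedged example, mirroring the PR description (placeholder URL again):

```py
# Hedged sketch: spark.udf now returns pyspark.sql.connect.udf.UDFRegistration
# rather than raising NotImplementedError. "sc://localhost" is a placeholder.
from pyspark.sql.connect.session import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
f = spark.udf.register("f", lambda x: x + 1, "int")
spark.sql("SELECT f(id) FROM range(2)").show()
```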
