[SPARK-47395] Add collate and collation to other APIs
### What changes were proposed in this pull request?

Added `collate` and `collation` functions to:
- Scala API
- Python API
- R API
- Spark Connect Scala Client
- Spark Connect Python Client

### Why are the changes needed?

To make these collation capabilities accessible from non-SQL APIs.

### Does this PR introduce _any_ user-facing change?

Yes, users can now access these functions from non-SQL APIs.
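
For example, a minimal PySpark sketch of the new API (the DataFrame and column names are placeholders, and `spark` is assumed to be an active `SparkSession`):

```python
from pyspark.sql import functions as F

# Tag a string column with an explicit collation, then read the collation name back.
df = spark.createDataFrame([("Alice",), ("Bob",)], ["name"])
df.select(F.collation(F.collate("name", "UNICODE"))).show()
```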

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#45517 from stefankandic/collateInDFAPI.

Authored-by: Stefan Kandic <stefan.kandic@databricks.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
stefankandic authored and MaxGekk committed Mar 15, 2024
1 parent 6bf0317 commit e2c0471
Showing 17 changed files with 212 additions and 0 deletions.
2 changes: 2 additions & 0 deletions R/pkg/NAMESPACE
@@ -259,6 +259,8 @@ exportMethods("%<=>%",
"cbrt",
"ceil",
"ceiling",
"collate",
"collation",
"collect_list",
"collect_set",
"column",
29 changes: 29 additions & 0 deletions R/pkg/R/functions.R
@@ -2516,6 +2516,35 @@ setMethod("upper",
column(jc)
})

#' @details
#' \code{collate}: Marks a given column with specified collation.
#'
#' @param x a Column on which to perform collate.
#' @param collation specified collation name.
#' @rdname column_string_functions
#' @aliases collate collate,Column-method
#' @note collate since 4.0.0
setMethod("collate",
signature(x = "Column", collation = "character"),
function(x, collation) {
jc <- callJStatic("org.apache.spark.sql.functions", "collate", x@jc, collation)
column(jc)
})

#' @details
#' \code{collation}: Returns the collation name of a given column.
#'
#' @param x a Column on which to return collation name.
#' @rdname column_string_functions
#' @aliases collation collation,Column-method
#' @note collation since 4.0.0
setMethod("collation",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "collation", x@jc)
column(jc)
})

#' @details
#' \code{var}: Alias for \code{var_samp}.
#'
8 changes: 8 additions & 0 deletions R/pkg/R/generics.R
@@ -921,6 +921,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
#' @name NULL
setGeneric("ceil", function(x) { standardGeneric("ceil") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("collate", function(x, collation) { standardGeneric("collate") })

#' @rdname column_string_functions
#' @name NULL
setGeneric("collation", function(x) { standardGeneric("collation") })

#' @rdname column_aggregate_functions
#' @name NULL
setGeneric("collect_list", function(x) { standardGeneric("collect_list") })
2 changes: 2 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1474,6 +1474,8 @@ test_that("column functions", {
c31 <- sec(c1) + csc(c1) + cot(c1)
c32 <- ln(c1) + positive(c2) + negative(c3)
c33 <- width_bucket(lit(2.5), lit(2.0), lit(3.0), lit(10L))
c34 <- collate(c, "UNICODE")
c35 <- collation(c)

# Test if base::is.nan() is exposed
expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE))
@@ -4845,6 +4845,22 @@ object functions {
otherChar: Column): Column =
Column.fn("mask", input, upperChar, lowerChar, digitChar, otherChar)

/**
* Marks a given column with specified collation.
*
* @group string_funcs
* @since 4.0.0
*/
def collate(e: Column, collation: String): Column = Column.fn("collate", e, lit(collation))

/**
* Returns the collation name of a given column.
*
* @group string_funcs
* @since 4.0.0
*/
def collation(e: Column): Column = Column.fn("collation", e)

//////////////////////////////////////////////////////////////////////////////////////////////
// DateTime functions
//////////////////////////////////////////////////////////////////////////////////////////////
@@ -1818,6 +1818,14 @@ class PlanGenerationTestSuite
fn.hours(Column("a"))
}

functionTest("collate") {
fn.collate(fn.col("g"), "UNICODE")
}

functionTest("collation") {
fn.collation(fn.col("g"))
}

temporalFunctionTest("convert_timezone with source time zone") {
fn.convert_timezone(lit("\"Africa/Dakar\""), lit("\"Asia/Urumqi\""), fn.col("t"))
}
@@ -0,0 +1,2 @@
Project [collate(g#0, UNICODE) AS collate(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,2 @@
Project [UTF8_BINARY AS collation(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,29 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "collate",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "g"
}
}, {
"literal": {
"string": "UNICODE"
}
}]
}
}]
}
}
Binary file not shown.
@@ -0,0 +1,25 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "collation",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "g"
}
}]
}
}]
}
}
Binary file not shown.
2 changes: 2 additions & 0 deletions python/docs/source/reference/pyspark.sql/functions.rst
@@ -160,6 +160,8 @@ String Functions
char
char_length
character_length
collate
collation
concat_ws
contains
decode
14 changes: 14 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -2853,6 +2853,20 @@ def mask(
mask.__doc__ = pysparkfuncs.mask.__doc__


def collate(col: "ColumnOrName", collation: str) -> Column:
return _invoke_function("collate", _to_col(col), lit(collation))


collate.__doc__ = pysparkfuncs.collate.__doc__


def collation(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns("collation", col)


collation.__doc__ = pysparkfuncs.collation.__doc__


# Date/Timestamp functions


52 changes: 52 additions & 0 deletions python/pyspark/sql/functions/builtin.py
@@ -12579,6 +12579,58 @@ def mask(
)


@_try_remote_functions
def collate(col: "ColumnOrName", collation: str) -> Column:
"""
Marks a given column with specified collation.

.. versionadded:: 4.0.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str
Target string column to work on.
collation : str
Target collation name.

Returns
-------
:class:`~pyspark.sql.Column`
A new column of string type, where each value has the specified collation.
"""
return _invoke_function("collate", _to_java_column(col), collation)


@_try_remote_functions
def collation(col: "ColumnOrName") -> Column:
"""
Returns the collation name of a given column.

.. versionadded:: 4.0.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str
Target string column to work on.

Returns
-------
:class:`~pyspark.sql.Column`
collation name of a given expression.

Examples
--------
>>> df = spark.createDataFrame([('name',)], ['dt'])
>>> df.select(collation('dt').alias('collation')).show()
+-----------+
| collation|
+-----------+
|UTF8_BINARY|
+-----------+
"""
return _invoke_function_over_columns("collation", col)


# ---------------------- Collection functions ------------------------------


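
Taken together, the new Python functions compose as in the sketch below (illustrative only; it assumes an active `SparkSession` bound to `spark`, with expected output shown as comments). Without an explicit `collate`, `collation` reports the default `UTF8_BINARY`, matching the docstring example above.

```python
from pyspark.sql import functions as F

df = spark.createDataFrame([("hello",)], ["s"])

# Without an explicit collation, collation() reports the default (UTF8_BINARY).
df.select(F.collation("s").alias("default_coll")).show()
# +------------+
# |default_coll|
# +------------+
# | UTF8_BINARY|
# +------------+

# After collate(), the column carries the requested collation instead.
df.select(F.collation(F.collate("s", "UNICODE")).alias("coll")).show()
# +-------+
# |   coll|
# +-------+
# |UNICODE|
# +-------+
```
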
5 changes: 5 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
@@ -375,6 +375,11 @@ def test_string_functions(self):
df.select(getattr(F, name)(F.col("name"))).first()[0],
)

def test_collation(self):
df = self.spark.createDataFrame([("a",), ("b",)], ["name"])
actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct().collect()
self.assertEqual([Row("UNICODE")], actual)

def test_octet_length_function(self):
# SPARK-36751: add octet length api for python
df = self.spark.createDataFrame([("cat",), ("\U0001F408",)], ["cat"])
16 changes: 16 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3944,6 +3944,22 @@ object functions {
*/
def octet_length(e: Column): Column = Column.fn("octet_length", e)

/**
* Marks a given column with specified collation.
*
* @group string_funcs
* @since 4.0.0
*/
def collate(e: Column, collation: String): Column = Column.fn("collate", e, lit(collation))

/**
* Returns the collation name of a given column.
*
* @group string_funcs
* @since 4.0.0
*/
def collation(e: Column): Column = Column.fn("collation", e)

/**
* Returns true if `str` matches `regexp`, or false otherwise.
*
