[SPARK-49713][PYTHON][CONNECT] Make function count_min_sketch accept number arguments

### What changes were proposed in this pull request?
1. Make function `count_min_sketch` accept number arguments (a usage sketch follows the error example below);
2. Make argument `seed` optional;
3. Fix the type hints of `eps`/`confidence`/`seed` from `ColumnOrName` to `Column`, because they require a foldable value and do not actually accept a column name:
```
In [3]: from pyspark.sql import functions as sf

In [4]: df = spark.range(10000).withColumn("seed", sf.lit(1).cast("int"))

In [5]: df.select(sf.hex(sf.count_min_sketch("id", sf.lit(0.5), sf.lit(0.5), "seed")))
...
AnalysisException: [DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "count_min_sketch(id, 0.5, 0.5, seed)" due to data type mismatch: the input `seed` should be a foldable "INT" expression; however, got "seed". SQLSTATE: 42K09;
'Aggregate [unresolvedalias('hex(count_min_sketch(id#1L, 0.5, 0.5, seed#2, 0, 0)))]
+- Project [id#1L, cast(1 as int) AS seed#2]
   +- Range (0, 10000, step=1, splits=Some(12))
...
```
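
For reference, a minimal sketch of the call styles this change enables (assumes a live `SparkSession` bound to `spark`; output omitted):

```python
from pyspark.sql import functions as sf

df = spark.range(10000)

# eps and confidence as plain floats, seed as a plain int
df.select(sf.hex(sf.count_min_sketch("id", 0.5, 0.5, 1))).show(truncate=False)

# seed omitted entirely: a random seed is chosen for you
df.select(sf.hex(sf.count_min_sketch("id", 0.5, 0.5))).show(truncate=False)
```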

### Why are the changes needed?
1. `seed` is optional in other similar functions;
2. The existing type hint `ColumnOrName` is misleading, since a column name is not actually supported.

### Does this PR introduce _any_ user-facing change?
Yes, `count_min_sketch` now accepts number arguments, and `seed` can be omitted.

### How was this patch tested?
Updated doctests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48157 from zhengruifeng/py_fix_count_min_sketch.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
zhengruifeng committed Sep 20, 2024
1 parent ca726c1 commit a5ac80a
Showing 3 changed files with 77 additions and 16 deletions.
10 changes: 6 additions & 4 deletions python/pyspark/sql/connect/functions/builtin.py
```diff
@@ -71,6 +71,7 @@
     StringType,
 )
 from pyspark.sql.utils import enum_to_value as _enum_to_value
+from pyspark.util import JVM_INT_MAX
 
 # The implementation of pandas_udf is embedded in pyspark.sql.function.pandas_udf
 # for code reuse.
@@ -1126,11 +1127,12 @@ def grouping_id(*cols: "ColumnOrName") -> Column:
 
 def count_min_sketch(
     col: "ColumnOrName",
-    eps: "ColumnOrName",
-    confidence: "ColumnOrName",
-    seed: "ColumnOrName",
+    eps: Union[Column, float],
+    confidence: Union[Column, float],
+    seed: Optional[Union[Column, int]] = None,
 ) -> Column:
-    return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed)
+    _seed = lit(random.randint(0, JVM_INT_MAX)) if seed is None else lit(seed)
+    return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed)
 
 
 count_min_sketch.__doc__ = pysparkfuncs.count_min_sketch.__doc__
```
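
The Connect client cannot defer to a JVM-side overload, so when `seed` is omitted the wrapper above draws a concrete value on the client and ships it as a foldable literal. A standalone sketch of just that defaulting logic, with `JVM_INT_MAX` restated here as an assumption about `pyspark.util`:

```python
import random

JVM_INT_MAX = 2**31 - 1  # assumed to mirror pyspark.util.JVM_INT_MAX

def resolve_seed(seed=None):
    # count_min_sketch requires a foldable INT seed, so a missing seed is
    # replaced by a concrete random value before the plan is built.
    return seed if seed is not None else random.randint(0, JVM_INT_MAX)

print(resolve_seed(7))  # 7
print(resolve_seed())   # some value in [0, 2**31 - 1]
```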
71 changes: 59 additions & 12 deletions python/pyspark/sql/functions/builtin.py
```diff
@@ -6015,9 +6015,9 @@ def grouping_id(*cols: "ColumnOrName") -> Column:
 @_try_remote_functions
 def count_min_sketch(
     col: "ColumnOrName",
-    eps: "ColumnOrName",
-    confidence: "ColumnOrName",
-    seed: "ColumnOrName",
+    eps: Union[Column, float],
+    confidence: Union[Column, float],
+    seed: Optional[Union[Column, int]] = None,
 ) -> Column:
     """
     Returns a count-min sketch of a column with the given eps, confidence and seed.
@@ -6031,26 +6031,73 @@
     ----------
     col : :class:`~pyspark.sql.Column` or str
         target column to compute on.
-    eps : :class:`~pyspark.sql.Column` or str
+    eps : :class:`~pyspark.sql.Column` or float
         relative error, must be positive
-    confidence : :class:`~pyspark.sql.Column` or str
+
+        .. versionchanged:: 4.0.0
+            `eps` now accepts float value.
+
+    confidence : :class:`~pyspark.sql.Column` or float
         confidence, must be positive and less than 1.0
-    seed : :class:`~pyspark.sql.Column` or str
+
+        .. versionchanged:: 4.0.0
+            `confidence` now accepts float value.
+
+    seed : :class:`~pyspark.sql.Column` or int, optional
         random seed
 
+        .. versionchanged:: 4.0.0
+            `seed` now accepts int value.
+
     Returns
     -------
     :class:`~pyspark.sql.Column`
         count-min sketch of the column
 
     Examples
     --------
-    >>> df = spark.createDataFrame([[1], [2], [1]], ['data'])
-    >>> df = df.agg(count_min_sketch(df.data, lit(0.5), lit(0.5), lit(1)).alias('sketch'))
-    >>> df.select(hex(df.sketch).alias('r')).collect()
-    [Row(r='0000000100000000000000030000000100000004000000005D8D6AB90000000000000000000000000000000200000000000000010000000000000000')]
-    """
-    return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed)
+    Example 1: Using columns as arguments
+
+    >>> from pyspark.sql import functions as sf
+    >>> spark.range(100).select(
+    ...     sf.hex(sf.count_min_sketch(sf.col("id"), sf.lit(3.0), sf.lit(0.1), sf.lit(1)))
+    ... ).show(truncate=False)
+    +------------------------------------------------------------------------+
+    |hex(count_min_sketch(id, 3.0, 0.1, 1))                                  |
+    +------------------------------------------------------------------------+
+    |0000000100000000000000640000000100000001000000005D8D6AB90000000000000064|
+    +------------------------------------------------------------------------+
+
+    Example 2: Using numbers as arguments
+
+    >>> from pyspark.sql import functions as sf
+    >>> spark.range(100).select(
+    ...     sf.hex(sf.count_min_sketch("id", 1.0, 0.3, 2))
+    ... ).show(truncate=False)
+    +----------------------------------------------------------------------------------------+
+    |hex(count_min_sketch(id, 1.0, 0.3, 2))                                                  |
+    +----------------------------------------------------------------------------------------+
+    |0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032|
+    +----------------------------------------------------------------------------------------+
+
+    Example 3: Using a random seed
+
+    >>> from pyspark.sql import functions as sf
+    >>> spark.range(100).select(
+    ...     sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6))
+    ... ).show(truncate=False)  # doctest: +SKIP
+    +----------------------------------------------------------------------------------------------------------------------------------------+
+    |hex(count_min_sketch(id, 1.5, 0.6, 2120704260))                                                                                          |
+    +----------------------------------------------------------------------------------------------------------------------------------------+
+    |0000000100000000000000640000000200000002000000005ADECCEE00000000153EBE090000000000000033000000000000003100000000000000320000000000000032|
+    +----------------------------------------------------------------------------------------------------------------------------------------+
+    """  # noqa: E501
+    _eps = lit(eps)
+    _conf = lit(confidence)
+    if seed is None:
+        return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf)
+    else:
+        return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf, lit(seed))
 
 
 @_try_remote_functions
```
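
The aggregate returns the serialized sketch as a binary column, so the raw bytes can be pulled back to the driver. A small usage sketch (assumes a `SparkSession` named `spark`; in PySpark a binary value comes back as a `bytearray`):

```python
from pyspark.sql import functions as sf

row = (
    spark.range(100)
    .agg(sf.count_min_sketch("id", 1.0, 0.3, 2).alias("sketch"))
    .head()
)
# Serialized CountMinSketch bytes, e.g. for deserialization on the JVM side.
raw = bytes(row["sketch"])
print(len(raw), raw[:8].hex())
```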
12 changes: 12 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
```diff
@@ -389,6 +389,18 @@ object functions {
   def count_min_sketch(e: Column, eps: Column, confidence: Column, seed: Column): Column =
     Column.fn("count_min_sketch", e, eps, confidence, seed)
 
+  /**
+   * Returns a count-min sketch of a column with the given eps, confidence and seed. The result
+   * is an array of bytes, which can be deserialized to a `CountMinSketch` before usage.
+   * Count-min sketch is a probabilistic data structure used for cardinality estimation using
+   * sub-linear space.
+   *
+   * @group agg_funcs
+   * @since 4.0.0
+   */
+  def count_min_sketch(e: Column, eps: Column, confidence: Column): Column =
+    count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextInt))
+
   private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
     Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
```
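
For intuition about what those bytes encode: a count-min sketch is a depth-by-width grid of counters indexed by independent hashes, and an item's estimate is the minimum over its counters. A toy Python illustration of the idea (not Spark's implementation):

```python
import hashlib

class TinyCMS:
    """Toy count-min sketch for illustration; not Spark's CountMinSketch."""

    def __init__(self, width: int, depth: int):
        self.width, self.depth = width, depth
        self.table = [[0] * width for _ in range(depth)]

    def _buckets(self, item):
        # One independent hash per row, derived by salting with the row index.
        for row in range(self.depth):
            digest = hashlib.blake2b(f"{row}:{item}".encode(), digest_size=8).digest()
            yield row, int.from_bytes(digest, "big") % self.width

    def add(self, item) -> None:
        for row, col in self._buckets(item):
            self.table[row][col] += 1

    def estimate(self, item) -> int:
        # Collisions can only inflate counters, so the minimum never undercounts.
        return min(self.table[row][col] for row, col in self._buckets(item))

cms = TinyCMS(width=64, depth=4)
for x in [1, 2, 1]:
    cms.add(x)
print(cms.estimate(1), cms.estimate(2))  # 2 1 (with high probability)
```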
