[SPARK-49422][CONNECT][SQL] Create a shared interface for KeyValueGroupedDataset

### What changes were proposed in this pull request?
This PR creates a shared interface for KeyValueGroupedDataset.

### Why are the changes needed?
We are creating a shared Scala Spark SQL interface for Classic and Connect.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.
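To illustrate the shared-interface pattern at a high level: the shared abstract class declares the Scala primitive once and implements the Java-friendly overloads on top of it, so Classic and Connect each override only the primitive. A minimal sketch with hypothetical names and a simplified signature, not the actual Spark source:

```scala
import scala.jdk.CollectionConverters._

import org.apache.spark.api.java.function.MapGroupsFunction
import org.apache.spark.sql.{Dataset, Encoder}

// Hypothetical sketch of the shared-interface pattern (not Spark source).
abstract class SharedKeyValueGroupedDataset[K, V] {
  // Each backend (Classic, Connect) implements the Scala primitive.
  def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U]

  // The Java-friendly overload is written once, adapting the Java
  // functional interface to a Scala function before delegating.
  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] =
    mapGroups((key: K, values: Iterator[V]) => f.call(key, values.asJava))(encoder)
}
```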

Closes apache#47960 from hvanhovell/SPARK-49422.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
hvanhovell committed Sep 3, 2024
1 parent e49dfcb commit 8301f92
Showing 12 changed files with 1,646 additions and 1,809 deletions.
The excerpt below is from the Spark Connect client's Dataset.scala, one of the 12 changed files.
```diff
@@ -33,11 +33,11 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.OrderUtils
 import org.apache.spark.sql.connect.client.SparkResult
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter, UdfUtils}
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions.{struct, to_json}
-import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, UnresolvedAttribute, UnresolvedRegex}
+import org.apache.spark.sql.internal.{ColumnNodeToProtoConverter, DataFrameWriterImpl, ToScalaUDF, UDFAdaptors, UnresolvedAttribute, UnresolvedRegex}
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types.{Metadata, StructType}
 import org.apache.spark.storage.StorageLevel
```
```diff
@@ -534,7 +534,7 @@ class Dataset[T] private[sql] (
    * @since 3.5.0
    */
   def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
-    groupByKey(UdfUtils.mapFunctionToScalaFunc(func))(encoder)
+    groupByKey(ToScalaUDF(func))(encoder)
 
   /** @inheritdoc */
   @scala.annotation.varargs
```
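For context, `ToScalaUDF` replaces the ad-hoc `UdfUtils.*ToScalaFunc` helpers with a single conversion point from Java functional interfaces to Scala functions. A minimal sketch of the idea, assuming simplified signatures (the real helper in `org.apache.spark.sql.internal` covers many more interfaces and arities):

```scala
import org.apache.spark.api.java.function.{FilterFunction, MapFunction}

// Minimal sketch of the ToScalaUDF idea (illustrative, not Spark source):
// one overloaded factory turning Java functional interfaces into Scala functions.
object ToScalaUDFSketch {
  def apply[T, U](f: MapFunction[T, U]): T => U = t => f.call(t)
  def apply[T](f: FilterFunction[T]): T => Boolean = t => f.call(t)
}
```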
```diff
@@ -865,17 +865,17 @@
 
   /** @inheritdoc */
   def filter(f: FilterFunction[T]): Dataset[T] = {
-    filter(UdfUtils.filterFuncToScalaFunc(f))
+    filter(ToScalaUDF(f))
   }
 
   /** @inheritdoc */
   def map[U: Encoder](f: T => U): Dataset[U] = {
-    mapPartitions(UdfUtils.mapFuncToMapPartitionsAdaptor(f))
+    mapPartitions(UDFAdaptors.mapToMapPartitions(f))
   }
 
   /** @inheritdoc */
   def map[U](f: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    map(UdfUtils.mapFunctionToScalaFunc(f))(encoder)
+    mapPartitions(UDFAdaptors.mapToMapPartitions(f))(encoder)
   }
 
   /** @inheritdoc */
```
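`UDFAdaptors.mapToMapPartitions` lifts a per-row function into a partition-level one, which is why both `map` overloads can now funnel into `mapPartitions`. A rough sketch of what such adaptors do (illustrative; the actual `org.apache.spark.sql.internal.UDFAdaptors` may differ):

```scala
// Rough sketch of the adaptor shapes (illustrative, not Spark source).
object UDFAdaptorsSketch {
  // Lift a per-element function to a whole-partition function.
  def mapToMapPartitions[T, U](f: T => U): Iterator[T] => Iterator[U] =
    input => input.map(f)

  // Same idea for flatMap-style generators.
  def flatMapToMapPartitions[T, U](f: T => IterableOnce[U]): Iterator[T] => Iterator[U] =
    input => input.flatMap(f)
}
```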
```diff
@@ -892,25 +892,11 @@
     }
   }
 
-  /** @inheritdoc */
-  def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    mapPartitions(UdfUtils.mapPartitionsFuncToScalaFunc(f))(encoder)
-  }
-
-  /** @inheritdoc */
-  override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] =
-    mapPartitions(UdfUtils.flatMapFuncToMapPartitionsAdaptor(func))
-
-  /** @inheritdoc */
-  override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    flatMap(UdfUtils.flatMapFuncToScalaFunc(f))(encoder)
-  }
-
   /** @inheritdoc */
   @deprecated("use flatMap() or select() with functions.explode() instead", "3.5.0")
   def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = {
     val generator = SparkUserDefinedFunction(
-      UdfUtils.iterableOnceToSeq(f),
+      UDFAdaptors.iterableOnceToSeq(f),
       UnboundRowEncoder :: Nil,
       ScalaReflection.encoderFor[Seq[A]])
     select(col("*"), functions.inline(generator(struct(input: _*))))
```
```diff
@@ -921,31 +907,16 @@
   def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)(
       f: A => IterableOnce[B]): DataFrame = {
     val generator = SparkUserDefinedFunction(
-      UdfUtils.iterableOnceToSeq(f),
+      UDFAdaptors.iterableOnceToSeq(f),
       Nil,
       ScalaReflection.encoderFor[Seq[B]])
     select(col("*"), functions.explode(generator(col(inputColumn))).as((outputColumn)))
   }
 
-  /** @inheritdoc */
-  def foreach(f: T => Unit): Unit = {
-    foreachPartition(UdfUtils.foreachFuncToForeachPartitionsAdaptor(f))
-  }
-
-  /** @inheritdoc */
-  override def foreach(func: ForeachFunction[T]): Unit =
-    foreach(UdfUtils.foreachFuncToScalaFunc(func))
-
   /** @inheritdoc */
   def foreachPartition(f: Iterator[T] => Unit): Unit = {
     // Delegate to mapPartition with empty result.
-    mapPartitions(UdfUtils.foreachPartitionFuncToMapPartitionsAdaptor(f))(RowEncoder(Seq.empty))
-      .collect()
-  }
-
-  /** @inheritdoc */
-  override def foreachPartition(func: ForeachPartitionFunction[T]): Unit = {
-    foreachPartition(UdfUtils.foreachPartitionFuncToScalaFunc(func))
+    mapPartitions(UDFAdaptors.foreachPartitionToMapPartitions(f))(NullEncoder).collect()
   }
 
   /** @inheritdoc */
```
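The rewritten `foreachPartition` routes everything through `mapPartitions` with a `NullEncoder` and an empty result, so there is a single execution path and no rows to serialize back. A sketch of the adaptor's likely shape (illustrative; the real `UDFAdaptors` helper may differ):

```scala
// Sketch of foreachPartitionToMapPartitions (illustrative, not Spark source):
// run the side-effecting function over the partition, then emit no rows.
def foreachPartitionToMapPartitions[T](f: Iterator[T] => Unit): Iterator[T] => Iterator[Nothing] =
  input => {
    f(input)       // perform the side effect for this partition
    Iterator.empty // no output rows; pairs with a null/empty encoder
  }
```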
```diff
@@ -1464,6 +1435,22 @@
   override def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] =
     super.dropDuplicatesWithinWatermark(col1, cols: _*)
 
+  /** @inheritdoc */
+  override def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+    super.mapPartitions(f, encoder)
+
+  /** @inheritdoc */
+  override def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] =
+    super.flatMap(func)
+
+  /** @inheritdoc */
+  override def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+    super.flatMap(f, encoder)
+
+  /** @inheritdoc */
+  override def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
+    super.foreachPartition(func)
+
   /** @inheritdoc */
   @scala.annotation.varargs
   override def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] =
```
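The new one-line `super` overrides look redundant, but in this kind of refactor they typically serve to keep the logic in the shared parent while exposing the concrete `Dataset` type to callers. A simplified, hypothetical illustration of that pattern (not the Spark source):

```scala
// Hypothetical illustration of override-to-narrow-the-return-type.
class BaseDs[T](val data: Vector[T]) {
  protected def construct[U](d: Vector[U]): BaseDs[U] = new BaseDs(d)
  def mapped[U](f: T => U): BaseDs[U] = construct(data.map(f))
}

class RichDs[T](data: Vector[T]) extends BaseDs[T](data) {
  override protected def construct[U](d: Vector[U]): RichDs[U] = new RichDs(d)
  // Delegate to the shared implementation, but expose the subclass type.
  override def mapped[U](f: T => U): RichDs[U] =
    super.mapped(f).asInstanceOf[RichDs[U]]
}
```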
