Skip to content

Commit

Permalink
[SPARK-49568][CONNECT][SQL] Remove self type from Dataset
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR removes the self type parameter from Dataset, which turned out to be a bit noisy. The self type is replaced by a combination of covariant return types and abstract types. Abstract types are used when a method takes a Dataset (or a KeyValueGroupedDataset) as an argument.

### Why are the changes needed?
The self type made using the classes in sql/api a bit noisy.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48146 from hvanhovell/SPARK-49568.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
  • Loading branch information
hvanhovell committed Sep 19, 2024
1 parent 5c48806 commit db8010b
Show file tree
Hide file tree
Showing 33 changed files with 500 additions and 364 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto.{NAReplace, Relation}
import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
import org.apache.spark.connect.proto.NAReplace.Replacement
import org.apache.spark.sql.connect.ConnectConversions._

/**
* Functionality for working with missing data in `DataFrame`s.
*
* @since 3.4.0
*/
final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation)
extends api.DataFrameNaFunctions[Dataset] {
extends api.DataFrameNaFunctions {
import sparkSession.RichColumn

override protected def drop(minNonNulls: Option[Int]): Dataset[Row] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._

import org.apache.spark.annotation.Stable
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.types.StructType

Expand All @@ -33,8 +34,8 @@ import org.apache.spark.sql.types.StructType
* @since 3.4.0
*/
@Stable
class DataFrameReader private[sql] (sparkSession: SparkSession)
extends api.DataFrameReader[Dataset] {
class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.DataFrameReader {
type DS[U] = Dataset[U]

/** @inheritdoc */
override def format(source: String): this.type = super.format(source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.{lang => jl, util => ju}
import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.functions.lit

/**
Expand All @@ -30,7 +31,7 @@ import org.apache.spark.sql.functions.lit
* @since 3.4.0
*/
final class DataFrameStatFunctions private[sql] (protected val df: DataFrame)
extends api.DataFrameStatFunctions[Dataset] {
extends api.DataFrameStatFunctions {
private def root: Relation = df.plan.getRoot
private val sparkSession: SparkSession = df.sparkSession

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.OrderUtils
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
Expand Down Expand Up @@ -134,8 +135,8 @@ class Dataset[T] private[sql] (
val sparkSession: SparkSession,
@DeveloperApi val plan: proto.Plan,
val encoder: Encoder[T])
extends api.Dataset[T, Dataset] {
type RGD = RelationalGroupedDataset
extends api.Dataset[T] {
type DS[U] = Dataset[U]

import sparkSession.RichColumn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.api.java.function._
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.UdfUtils
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions.col
Expand All @@ -40,8 +41,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode
*
* @since 3.5.0
*/
class KeyValueGroupedDataset[K, V] private[sql] ()
extends api.KeyValueGroupedDataset[K, V, Dataset] {
class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDataset[K, V] {
type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL]

private def unsupported(): Nothing = throw new UnsupportedOperationException()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql
import scala.jdk.CollectionConverters._

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.ConnectConversions._

/**
* A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
Expand All @@ -39,8 +40,7 @@ class RelationalGroupedDataset private[sql] (
groupType: proto.Aggregate.GroupType,
pivot: Option[proto.Aggregate.Pivot] = None,
groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None)
extends api.RelationalGroupedDataset[Dataset] {
type RGD = RelationalGroupedDataset
extends api.RelationalGroupedDataset {
import df.sparkSession.RichColumn

protected def toDF(aggExprs: Seq[Column]): DataFrame = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ import org.apache.spark.util.ArrayImplicits._
class SparkSession private[sql] (
private[sql] val client: SparkConnectClient,
private val planIdGenerator: AtomicLong)
extends api.SparkSession[Dataset]
extends api.SparkSession
with Logging {

private[this] val allocator = new RootAllocator()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalog
import java.util

import org.apache.spark.sql.{api, DataFrame, Dataset}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.types.StructType

/** @inheritdoc */
abstract class Catalog extends api.Catalog[Dataset] {
abstract class Catalog extends api.Catalog {

/** @inheritdoc */
override def listDatabases(): Dataset[Database]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect

import scala.language.implicitConversions

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql._

/**
* Conversions from sql interfaces to the Connect specific implementation.
*
* These conversions are mainly used by the implementation. In the case of connect it should be
* extremely rare that a developer needs them.
*
* We provide both a trait and an object. The trait is useful in situations where an extension
* developer needs to use these conversions in a project covering multiple Spark versions. They
* can create a shim for these conversions, the Spark 4+ version of the shim implements this
* trait, and shims for older versions do not.
*
* Every conversion below is a plain `asInstanceOf` downcast with no prior type test. Passing an
* instance that is not the Connect implementation therefore fails at runtime with a
* `ClassCastException` instead of being rejected at compile time.
*/
@DeveloperApi
trait ConnectConversions {
/**
* Downcasts an `api.SparkSession` to the Connect [[SparkSession]].
*
* @param session the session, typed against the shared sql interface.
* @return the same instance, typed as the Connect implementation.
*/
implicit def castToImpl(session: api.SparkSession): SparkSession =
session.asInstanceOf[SparkSession]

/**
* Downcasts an `api.Dataset` to the Connect [[Dataset]], keeping the element type `T`.
*
* @param ds the dataset, typed against the shared sql interface.
* @return the same instance, typed as the Connect implementation.
*/
implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] =
ds.asInstanceOf[Dataset[T]]

/**
* Downcasts an `api.RelationalGroupedDataset` to the Connect [[RelationalGroupedDataset]].
*
* @param rgds the grouped dataset, typed against the shared sql interface.
* @return the same instance, typed as the Connect implementation.
*/
implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset =
rgds.asInstanceOf[RelationalGroupedDataset]

/**
* Downcasts an `api.KeyValueGroupedDataset` to the Connect [[KeyValueGroupedDataset]], keeping
* the key type `K` and value type `V`.
*
* @param kvds the key-value grouped dataset, typed against the shared sql interface.
* @return the same instance, typed as the Connect implementation.
*/
implicit def castToImpl[K, V](
kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] =
kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]
}

/**
* Singleton instance of the [[ConnectConversions]] trait. Importing its members
* (`import org.apache.spark.sql.connect.ConnectConversions._`) brings the implicit conversions
* into scope.
*/
object ConnectConversions extends ConnectConversions
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.connect.proto.StreamingQueryCommand
import org.apache.spark.connect.proto.StreamingQueryCommandResult
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
import org.apache.spark.sql.{api, Dataset, SparkSession}
import org.apache.spark.sql.{api, SparkSession}

/** @inheritdoc */
trait StreamingQuery extends api.StreamingQuery[Dataset] {
trait StreamingQuery extends api.StreamingQuery {

/** @inheritdoc */
override def sparkSession: SparkSession
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.columnar.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.jdbc.*"),
Expand Down
2 changes: 2 additions & 0 deletions project/MimaExcludes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ object MimaExcludes {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.errors.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connect.*"),
// DSv2 catalog and expression APIs are unstable yet. We should enable this back.
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.catalog.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.expressions.*"),
Expand Down
1 change: 1 addition & 0 deletions project/SparkBuild.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,7 @@ trait SharedUnidocSettings {
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect/")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/classic/")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalog/v2/utils")))
Expand Down
Loading

0 comments on commit db8010b

Please sign in to comment.