[SPARK-38932][SQL] Datasource v2 support report distinct keys #36253

Closed · wants to merge 1 commit
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.read;

import java.util.Set;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
* A mix-in interface for {@link Scan}. Data sources can implement this interface to
* report a set of unique keys to Spark.
* <p>
* Spark will optimize the query plan according to the given unique keys.
* For example, Spark will eliminate the `Distinct` in the plan below if the v2 relation only
* outputs the unique attributes.
* <pre>
* Distinct
* +- RelationV2[unique_key#1]
* </pre>
* <p>
* Note that Spark does not validate whether the reported values are actually unique. The
* implementation must guarantee this.
*
* @since 3.4.0
*/
@Evolving
public interface SupportsReportDistinctKeys extends Scan {
/**
* Returns a set of unique keys. Each unique key can consist of multiple attributes.
*/
Set<Set<NamedReference>> distinctKeysSet();
}
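For illustration, a minimal connector-side sketch of implementing this interface (assuming it lands as proposed, per the @since 3.4.0 tag). The scan class, schema handling, and the `id` column below are hypothetical and only the key-reporting pieces are shown; `FieldReference.apply` parses a column name into a `NamedReference`.

import java.util.{HashSet => JHashSet, Set => JSet}

import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.read.{Scan, SupportsReportDistinctKeys}
import org.apache.spark.sql.types.StructType

// Hypothetical scan over a table whose `id` column is known to be unique.
// Partition planning and the rest of the scan are omitted.
class MyUniqueKeyScan(schema: StructType) extends Scan with SupportsReportDistinctKeys {
  override def readSchema(): StructType = schema

  // Each inner set is one (possibly compound) unique key; here the single key is {id}.
  override def distinctKeysSet(): JSet[JSet[NamedReference]] = {
    val key = new JHashSet[NamedReference]()
    key.add(FieldReference("id"))
    val keys = new JHashSet[JSet[NamedReference]]()
    keys.add(key)
    keys
  }
}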
@@ -49,6 +49,10 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
refs.map(ref => resolveRef[T](ref, plan))
}

def resolveRefs[T <: NamedExpression](refs: Set[NamedReference], plan: LogicalPlan): Set[T] = {
refs.map(ref => resolveRef[T](ref, plan))
}

/**
* Converts the array of input V2 [[V2SortOrder]] into their counterparts in catalyst.
*/
@@ -62,7 +62,10 @@ object DistinctKeyVisitor extends LogicalPlanVisitor[Set[ExpressionSet]] {
}
}

override def default(p: LogicalPlan): Set[ExpressionSet] = Set.empty[ExpressionSet]
override def default(p: LogicalPlan): Set[ExpressionSet] = p match {
case leaf: LeafNode => leaf.reportDistinctKeysSet()
case _ => Set.empty[ExpressionSet]
}

override def visitAggregate(p: Aggregate): Set[ExpressionSet] = {
// handle group by a, a and global aggregate
@@ -167,6 +167,9 @@ abstract class LogicalPlan
trait LeafNode extends LogicalPlan with LeafLike[LogicalPlan] {
override def producedAttributes: AttributeSet = outputSet

/** Returns the set of unique keys reported by this leaf node, if any. */
def reportDistinctKeysSet(): Set[ExpressionSet] = Set.empty[ExpressionSet]

/** Leaf nodes that can survive analysis must define their own statistics. */
def computeStats(): Statistics = throw new UnsupportedOperationException
}
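Together with the `DistinctKeyVisitor.default` change above, this hook lets leaf relations feed connector-reported keys into distinct-key propagation, so the optimizer can drop a `Distinct` or a redundant `Aggregate` once one reported key is contained in the grouping attributes (the behavior the SPARK-38932 tests below assert). A simplified containment check, using plain Scala sets rather than Spark's `ExpressionSet`, standing in for rather than reproducing the actual optimizer rule:

// Simplified sketch: a Distinct (or group-by-only Aggregate) over `groupingAttrs` is
// redundant when at least one reported unique key is fully contained in those attributes.
def distinctIsRedundant[A](reportedKeys: Set[Set[A]], groupingAttrs: Set[A]): Boolean =
  reportedKeys.exists(key => key.nonEmpty && key.subsetOf(groupingAttrs))

// For example, a leaf reporting the key {"key"} makes `GROUP BY key, value` redundant:
// distinctIsRedundant(Set(Set("key")), Set("key", "value")) == true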
@@ -17,12 +17,14 @@

package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ExpressionSet, SortOrder, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability}
import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics}
import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportDistinctKeys, SupportsReportStatistics}
import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils
@@ -75,6 +77,13 @@ case class DataSourceV2Relation(
s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $qualifiedTableName $name"
}

override def reportDistinctKeysSet(): Set[ExpressionSet] = {
table.asReadable.newScanBuilder(options).build() match {
case r: SupportsReportDistinctKeys => DataSourceV2Relation.transformUniqueKeysSet(r, this)
case _ => super.reportDistinctKeysSet()
}
}

override def computeStats(): Statistics = {
if (Utils.isTesting) {
// when testing, throw an exception if this computeStats method is called because stats should
@@ -134,6 +143,11 @@ case class DataSourceV2ScanRelation(
s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name"
}

override def reportDistinctKeysSet(): Set[ExpressionSet] = scan match {
case r: SupportsReportDistinctKeys => DataSourceV2Relation.transformUniqueKeysSet(r, this)
case _ => super.reportDistinctKeysSet()
}

override def computeStats(): Statistics = {
scan match {
case r: SupportsReportStatistics =>
@@ -166,6 +180,11 @@ case class StreamingDataSourceV2Relation(

override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))

override def reportDistinctKeysSet(): Set[ExpressionSet] = scan match {
case r: SupportsReportDistinctKeys => DataSourceV2Relation.transformUniqueKeysSet(r, this)
case _ => super.reportDistinctKeysSet()
}

override def computeStats(): Statistics = scan match {
case r: SupportsReportStatistics =>
val statistics = r.estimateStatistics()
@@ -220,4 +239,13 @@ object DataSourceV2Relation {
sizeInBytes = v2Statistics.sizeInBytes().orElse(defaultSizeInBytes),
rowCount = numRows)
}

def transformUniqueKeysSet(
r: SupportsReportDistinctKeys,
p: LogicalPlan): Set[ExpressionSet] = {
val uniqueKeysSet = r.distinctKeysSet().asScala
uniqueKeysSet.map { uniqueKeys =>
ExpressionSet(V2ExpressionUtils.resolveRefs(uniqueKeys.asScala.toSet, p))
}.toSet
}
}
@@ -24,6 +24,7 @@ import java.util.OptionalLong

import scala.collection.mutable

import com.google.common.collect.Sets
import org.scalatest.Assertions._

import org.apache.spark.sql.catalyst.InternalRow
@@ -272,10 +273,22 @@ class InMemoryTable(
var data: Seq[InputPartition],
readSchema: StructType,
tableSchema: StructType)
extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning {
extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning
with SupportsReportDistinctKeys {

override def toBatch: Batch = this

override def distinctKeysSet(): java.util.Set[java.util.Set[NamedReference]] = {
val uniqueKeys = readSchema.fields.collect {
case f if f.metadata.contains("unique") => f.name
}.map(FieldReference(_))
.map(Sets.newHashSet(_))

Sets.newHashSet(uniqueKeys: _*)
}

override def estimateStatistics(): Statistics = {
if (data.isEmpty) {
return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L))
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package test.org.apache.spark.sql.connector;

import com.google.common.collect.Sets;
import org.apache.spark.sql.connector.TestingV2Source;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.expressions.FieldReference;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.read.*;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

import java.util.Set;

public class JavaReportDistinctKeysDataSource implements TestingV2Source {
static class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportDistinctKeys {
@Override
public Set<Set<NamedReference>> distinctKeysSet() {
return Sets.newHashSet(
Sets.newHashSet(FieldReference.apply("i")),
Sets.newHashSet(FieldReference.apply("j")));
}

@Override
public InputPartition[] planInputPartitions() {
InputPartition[] partitions = new InputPartition[1];
partitions[0] = new JavaRangeInputPartition(0, 1);
return partitions;
}
}

@Override
public Table getTable(CaseInsensitiveStringMap options) {
return new JavaSimpleBatchTable() {
@Override
public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new JavaReportDistinctKeysDataSource.MyScanBuilder();
}
};
}
}
@@ -21,12 +21,12 @@ import java.util.Collections

import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{IntegerType, Metadata, StructField, StructType}
import org.apache.spark.sql.util.QueryExecutionListener

class DataSourceV2DataFrameSuite
@@ -253,4 +253,22 @@ class DataSourceV2DataFrameSuite
spark.listenerManager.unregister(listener)
}
}

test("SPARK-38932: Datasource v2 support report distinct keys") {
val t = "testcat.unique.t"
withTable(t) {
val unique = """{"unique":""}"""
val schema = StructType(
StructField("key", IntegerType, metadata = Metadata.fromJson(unique)) ::
StructField("value", IntegerType) :: Nil)
val data = spark.sparkContext.parallelize(Row(1, 1) :: Row(2, 1) :: Nil)
spark.createDataFrame(data, schema).writeTo(t).create()

val qe = spark.table(t).groupBy($"key").agg($"key").queryExecution
val analyzed = qe.analyzed
val optimized = qe.optimizedPlan
assert(analyzed.exists(_.isInstanceOf[Aggregate]))
assert(!optimized.exists(_.isInstanceOf[Aggregate]))
}
}
}
@@ -20,10 +20,13 @@ package org.apache.spark.sql.connector
import java.io.File
import java.util.OptionalLong

import com.google.common.collect.Sets
import test.org.apache.spark.sql.connector._

import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.ExpressionSet
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, NamedReference, NullOrdering, SortDirection, SortOrder, Transform}
@@ -596,6 +599,30 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}
}
}

test("SPARK-38932: Datasource v2 support report distinct keys") {
def checkUniqueKeys(leaf: LeafNode): Unit = {
// Assume all output attributes are unique keys
val expected = leaf.output.map(attr => ExpressionSet(attr :: Nil)).toSet
assert(leaf.reportDistinctKeysSet() == expected)
}

Seq(classOf[ReportUniqueKeysDataSource], classOf[JavaReportDistinctKeysDataSource]).foreach {
cls =>
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
val analyzed = df.queryExecution.analyzed.collect {
case d: DataSourceV2Relation => d
}.head
checkUniqueKeys(analyzed)

val optimized = df.queryExecution.optimizedPlan.collect {
case d: DataSourceV2ScanRelation => d
}.head
checkUniqueKeys(optimized)
}
}
}
}


@@ -1106,3 +1133,26 @@ class ReportStatisticsDataSource extends SimpleWritableDataSource {
}
}
}

class ReportUniqueKeysDataSource extends SimpleWritableDataSource {

class MyScanBuilder extends SimpleScanBuilder with SupportsReportDistinctKeys {
override def distinctKeysSet(): java.util.Set[java.util.Set[NamedReference]] = {
Sets.newHashSet(
Sets.newHashSet(FieldReference.apply("i")),
Sets.newHashSet(FieldReference.apply("j")))
}

override def planInputPartitions(): Array[InputPartition] = {
Array(RangeInputPartition(0, 1))
}
}

override def getTable(options: CaseInsensitiveStringMap): Table = {
new SimpleBatchTable {
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder
}
}
}
}