[SPARK-48770][SS] Change to read operator metadata once on driver to check if we can find info for numColsPrefixKey used for session window agg queries #47167

Status: Closed · wants to merge 3 commits (showing changes from all commits)
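
For context, the statestore format registered below by StateDataSource (shortName() returns "statestore") is the reader this change serves. A minimal usage sketch, assuming a SparkSession named spark, a hypothetical checkpoint location, and the batchId and operatorId options of the state data source reader:

  // Query the state written by a session window aggregation query.
  // "/tmp/checkpoint" is a hypothetical checkpoint location; batchId and operatorId
  // identify the micro-batch and the stateful operator whose state store to read.
  val stateDf = spark.read
    .format("statestore")
    .option("batchId", "10")     // state as of micro-batch 10
    .option("operatorId", "0")   // the session window aggregation operator
    .load("/tmp/checkpoint")

  stateDf.select("key", "value").show(truncate = false)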
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -30,13 +30,15 @@ import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
 import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
 import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.SerializableConfiguration
 
 /**
  * An implementation of [[TableProvider]] with [[DataSourceRegister]] for State Store data source.
@@ -46,6 +48,8 @@ class StateDataSource extends TableProvider with DataSourceRegister {
 
   private lazy val hadoopConf: Configuration = session.sessionState.newHadoopConf()
 
+  private lazy val serializedHadoopConf = new SerializableConfiguration(hadoopConf)
+
   override def shortName(): String = "statestore"
 
   override def getTable(
@@ -54,7 +58,17 @@ class StateDataSource extends TableProvider with DataSourceRegister {
       properties: util.Map[String, String]): Table = {
     val sourceOptions = StateSourceOptions.apply(session, hadoopConf, properties)
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
-    new StateTable(session, schema, sourceOptions, stateConf)
+    // Read the operator metadata once to see if we can find the information for prefix scan
+    // encoder used in session window aggregation queries.
+    val allStateStoreMetadata = new StateMetadataPartitionReader(
+      sourceOptions.stateCheckpointLocation.getParent.toString, serializedHadoopConf)
+      .stateMetadata.toArray
+    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
+      entry.operatorId == sourceOptions.operatorId &&
+      entry.stateStoreName == sourceOptions.storeName
+    }
+
+    new StateTable(session, schema, sourceOptions, stateConf, stateStoreMetadata)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
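The hunk above is the core of this PR: getTable now reads the operator metadata once on the driver, filters it down to the entries for the requested operator and store, and passes the result into StateTable, instead of every partition reader re-reading the metadata (see the removal in StatePartitionReader below). A self-contained sketch of that driver-side resolution, using a stand-in case class rather than the real StateMetadataTableEntry (the field names here are assumptions):

  // Stand-in for StateMetadataTableEntry; the real entry carries more fields.
  case class MetadataEntry(operatorId: Long, stateStoreName: String, numColsPrefixKey: Int)

  // Resolve once on the driver: narrow the full metadata listing to the requested
  // operator and store, then hand only that result to the per-partition readers.
  def resolveStoreMetadata(
      allEntries: Seq[MetadataEntry],
      operatorId: Long,
      storeName: String): Seq[MetadataEntry] = {
    allEntries.filter { entry =>
      entry.operatorId == operatorId && entry.stateStoreName == storeName
    }
  }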
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
-import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, ReadStateStore, StateStoreConf, StateStoreId, StateStoreProvider, StateStoreProviderId}
 import org.apache.spark.sql.types.StructType
@@ -33,11 +33,12 @@ import org.apache.spark.util.SerializableConfiguration
 class StatePartitionReaderFactory(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
-    schema: StructType) extends PartitionReaderFactory {
+    schema: StructType,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
     new StatePartitionReader(storeConf, hadoopConf,
-      partition.asInstanceOf[StateStoreInputPartition], schema)
+      partition.asInstanceOf[StateStoreInputPartition], schema, stateStoreMetadata)
   }
 }
 
@@ -49,7 +50,9 @@ class StatePartitionReader(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
-    schema: StructType) extends PartitionReader[InternalRow] with Logging {
+    schema: StructType,
+    stateStoreMetadata: Array[StateMetadataTableEntry])
+  extends PartitionReader[InternalRow] with Logging {
 
   private val keySchema = SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType]
   private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, "value").asInstanceOf[StructType]
@@ -58,13 +61,6 @@ class StatePartitionReader(
     val stateStoreId = StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
       partition.sourceOptions.operatorId, partition.partition, partition.sourceOptions.storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, partition.queryId)
-    val allStateStoreMetadata = new StateMetadataPartitionReader(
-      partition.sourceOptions.stateCheckpointLocation.getParent.toString, hadoopConf)
-      .stateMetadata.toArray
-    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
-      entry.operatorId == partition.sourceOptions.operatorId &&
-      entry.stateStoreName == partition.sourceOptions.storeName
-    }
     val numColsPrefixKey = if (stateStoreMetadata.isEmpty) {
       logWarning("Metadata for state store not found, possible cause is this checkpoint " +
         "is created by older version of spark. If the query has session window aggregation, " +
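The truncated block above derives numColsPrefixKey from the metadata entries that now arrive pre-resolved from the driver, and warns when the checkpoint was written by an older Spark without operator metadata. A hedged sketch of how that value typically selects the key encoder spec, assuming the constructors suggested by the imports above, PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) and NoPrefixKeyStateEncoderSpec(keySchema):

  import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec}
  import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

  // Hypothetical key schema for a session window aggregation: grouping key plus session window.
  val keySchema = StructType(Seq(
    StructField("sessionKey", StringType),
    StructField("sessionStart", LongType)))

  // numColsPrefixKey > 0 means the store was written with a prefix-scan-capable encoder,
  // which session window aggregation relies on; 0 means no prefix key columns.
  val numColsPrefixKey = 1
  val keyStateEncoderSpec = if (numColsPrefixKey > 0) {
    PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey)
  } else {
    NoPrefixKeyStateEncoderSpec(keySchema)
  }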
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan, ScanBuilder}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
 import org.apache.spark.sql.types.StructType
@@ -35,8 +36,10 @@ class StateScanBuilder(
     session: SparkSession,
     schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateStoreConf: StateStoreConf) extends ScanBuilder {
-  override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf)
+    stateStoreConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends ScanBuilder {
+  override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
+    stateStoreMetadata)
 }
 
 /** An implementation of [[InputPartition]] for State Store data source. */
@@ -50,7 +53,8 @@ class StateScan(
     session: SparkSession,
     schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateStoreConf: StateStoreConf) extends Scan with Batch {
+    stateStoreConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry]) extends Scan with Batch {
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
   private val hadoopConfBroadcast = session.sparkContext.broadcast(
@@ -62,7 +66,8 @@ class StateScan(
     val fs = stateCheckpointPartitionsLocation.getFileSystem(hadoopConfBroadcast.value.value)
     val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() {
       override def accept(path: Path): Boolean = {
-        fs.isDirectory(path) && Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
+        fs.getFileStatus(path).isDirectory &&
+          Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
       }
     })
 
@@ -105,7 +110,8 @@ class StateScan(
         hadoopConfBroadcast.value, userFacingSchema, stateSchema)
 
     case JoinSideValues.none =>
-      new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema)
+      new StatePartitionReaderFactory(stateStoreConf, hadoopConfBroadcast.value, schema,
+        stateStoreMetadata)
   }
 
   override def toBatch: Batch = this
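The second hunk in this file also replaces fs.isDirectory(path), a FileSystem convenience method that Hadoop has deprecated, with fs.getFileStatus(path).isDirectory when listing state partition directories (subdirectories named 0, 1, 2, ...). A standalone sketch of that filter, with a hypothetical checkpoint path:

  import scala.util.Try
  import org.apache.hadoop.conf.Configuration
  import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}

  // Keep only subdirectories whose names are non-negative integers (state partition ids).
  val partitionsRoot = new Path("/tmp/checkpoint/state/0")
  val fs: FileSystem = partitionsRoot.getFileSystem(new Configuration())
  val partitions = fs.listStatus(partitionsRoot, new PathFilter {
    override def accept(path: Path): Boolean =
      fs.getFileStatus(path).isDirectory &&
        Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
  })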
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsMetadataColumns, SupportsRead, Table, TableCapability}
 import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
 import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -35,7 +36,8 @@ class StateTable(
     session: SparkSession,
     override val schema: StructType,
     sourceOptions: StateSourceOptions,
-    stateConf: StateStoreConf)
+    stateConf: StateStoreConf,
+    stateStoreMetadata: Array[StateMetadataTableEntry])
   extends Table with SupportsRead with SupportsMetadataColumns {
 
   import StateTable._
@@ -64,7 +66,7 @@ class StateTable(
   override def capabilities(): util.Set[TableCapability] = CAPABILITY
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
-    new StateScanBuilder(session, schema, sourceOptions, stateConf)
+    new StateScanBuilder(session, schema, sourceOptions, stateConf, stateStoreMetadata)
 
   override def properties(): util.Map[String, String] = Map.empty[String, String].asJava
 
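The remaining edits are plumbing: StateTable stores the resolved metadata and passes it to StateScanBuilder, which forwards it through StateScan to StatePartitionReaderFactory, so each partition reader receives the entries without touching the metadata files. A simplified, self-contained sketch of that pass-through chain (the Example* types are illustrative stand-ins, not Spark's DSv2 interfaces):

  // A value resolved once at table creation time...
  final case class ResolvedMetadata(numColsPrefixKey: Int)

  // ...is threaded Table -> ScanBuilder -> Scan -> PartitionReaderFactory unchanged,
  // so no per-partition reader has to recompute it.
  class ExampleTable(metadata: ResolvedMetadata) {
    def newScanBuilder(): ExampleScanBuilder = new ExampleScanBuilder(metadata)
  }
  class ExampleScanBuilder(metadata: ResolvedMetadata) {
    def build(): ExampleScan = new ExampleScan(metadata)
  }
  class ExampleScan(metadata: ResolvedMetadata) {
    def createReaderFactory(): ExampleReaderFactory = new ExampleReaderFactory(metadata)
  }
  class ExampleReaderFactory(metadata: ResolvedMetadata) {
    def createReader(partitionId: Int): String =
      s"partition $partitionId reads with numColsPrefixKey=${metadata.numColsPrefixKey}"
  }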