Skip to content

Introducing StateSchemaV3 for the TransformWithState operator #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.ListState

/**
Expand All @@ -44,8 +44,9 @@ class ListStateImpl[S](

private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName)

store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true)
// List state keeps several values under one grouping key. The call this schema object stands in
// for passed `useMultipleValuesPerKey = true`, and ListStateImplWithTTL passes `true` in the same
// position, so the final argument must be `true` here as well (it was `false`).
val columnFamilySchema = new ColumnFamilySchemaV1(
  stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true)
store.createColFamilyIfAbsent(columnFamilySchema)

/** Whether state exists or not. */
override def exists(): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{ListState, TTLConfig}
import org.apache.spark.util.NextIterator

Expand Down Expand Up @@ -52,11 +52,13 @@ class ListStateImplWithTTL[S](
private lazy val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

val columnFamilySchema = new ColumnFamilySchemaV1(
stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true)
initialize()

/**
 * Creates the column family backing this list state variable in the state store, if it is
 * not already present.
 */
private def initialize(): Unit = {
// NOTE(review): two createColFamilyIfAbsent calls appear here - the explicit-argument form and
// the schema-object form. This looks like both sides of a diff rendered together; confirm
// which call is meant to remain.
store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true)
store.createColFamilyIfAbsent(columnFamilySchema)
}

/** Whether state exists or not. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
import org.apache.spark.sql.streaming.MapState
import org.apache.spark.sql.types.{BinaryType, StructType}

class MapStateImpl[K, V](
store: StateStore,
Expand All @@ -30,18 +30,15 @@ class MapStateImpl[K, V](
userKeyEnc: Encoder[K],
valEncoder: Encoder[V]) extends MapState[K, V] with Logging {

// Pack grouping key and user key together as a prefixed composite key
private val schemaForCompositeKeyRow: StructType =
new StructType()
.add("key", BinaryType)
.add("userKey", BinaryType)
private val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
private val keySerializer = keyExprEnc.createSerializer()
private val stateTypesEncoder = new CompositeKeyStateEncoder(
keySerializer, userKeyEnc, valEncoder, schemaForCompositeKeyRow, stateName)
keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName)

store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow, schemaForValueRow,
PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1))
val columnFamilySchema = new ColumnFamilySchemaV1(
stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)

store.createColFamilyIfAbsent(columnFamilySchema)

/** Whether state exists or not. */
override def exists(): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{MapState, TTLConfig}
import org.apache.spark.util.NextIterator

Expand Down Expand Up @@ -55,11 +55,13 @@ class MapStateImplWithTTL[K, V](
private val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

val columnFamilySchema = new ColumnFamilySchemaV1(
stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)
initialize()

/**
 * Creates the column family backing this map state variable in the state store, if it is
 * not already present.
 */
private def initialize(): Unit = {
// NOTE(review): two createColFamilyIfAbsent calls appear here - the explicit-argument form and
// the schema-object form. This looks like both sides of a diff rendered together; confirm
// which call is meant to remain.
store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1))
store.createColFamilyIfAbsent(columnFamilySchema)
}

/** Whether state exists or not. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ class StatefulProcessorHandleImpl(
timeMode: TimeMode,
isStreaming: Boolean = true,
batchTimestampMs: Option[Long] = None,
metrics: Map[String, SQLMetric] = Map.empty)
metrics: Map[String, SQLMetric] = Map.empty,
existingColFamilies: Map[String, ColumnFamilySchemaV1] = Map.empty)
extends StatefulProcessorHandle with Logging {
import StatefulProcessorHandleState._

Expand All @@ -97,6 +98,9 @@ class StatefulProcessorHandleImpl(
private[sql] val stateVariables: util.List[StateVariableInfo] =
new util.ArrayList[StateVariableInfo]()

private[sql] val columnFamilySchemas: util.List[ColumnFamilySchema] =
new util.ArrayList[ColumnFamilySchema]()

private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000"

private def buildQueryInfo(): QueryInfo = {
Expand Down Expand Up @@ -128,13 +132,29 @@ class StatefulProcessorHandleImpl(

def getHandleState: StatefulProcessorHandleState = currState

/**
 * Checks a newly declared state variable against the column family already known under the
 * same name, if there is one.
 *
 * Nothing happens when no column family of that name exists or when the JSON forms of both
 * schemas are equal.
 *
 * @param newColumnFamilySchema - schema of the state variable being created
 * @throws RuntimeException if a column family with the same name exists with a different schema
 */
def validateStateVariableCreation(newColumnFamilySchema: ColumnFamilySchemaV1): Unit = {
  val colFamilyName = newColumnFamilySchema.columnFamilyName
  existingColFamilies.get(colFamilyName) match {
    // TODO: Fill in with conditions we need to validate new state variable creation
    case Some(existingColFamily) if existingColFamily.json != newColumnFamilySchema.json =>
      throw new RuntimeException(
        s"State variable with name ${newColumnFamilySchema.columnFamilyName} already exists " +
        s"with different schema. Existing schema: ${existingColFamily.json}, " +
        s"New schema: ${newColumnFamilySchema.json}")
    case _ => // no existing column family, or the schemas match
  }
}

/**
 * Creates the value state variable `stateName` and records it on this handle.
 *
 * After the handle-state check, a [[ValueStateImpl]] is constructed, its column family schema
 * is validated against any existing column family of the same name, and the schema is appended
 * to `columnFamilySchemas`.
 *
 * @param stateName - name of the state variable
 * @param valEncoder - SQL encoder for state variable
 * @tparam T - type of state variable
 * @return - the newly constructed value state
 */
override def getValueState[T](
stateName: String,
valEncoder: Encoder[T]): ValueState[T] = {
verifyStateVarOperations("get_value_state")
// NOTE(review): the same StateVariableInfo is added again a few lines below. This looks like
// the removed and added sides of a diff shown together; confirm which registration remains.
stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
incrementMetric("numValueStateVars")
val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder)
stateVariables.add(new StateVariableInfo(stateName, ValueState, false))
val colFamilySchema = resultState.columnFamilySchema
// Throws when a column family of this name already exists with a different schema.
validateStateVariableCreation(colFamilySchema)
columnFamilySchemas.add(colFamilySchema)
resultState
}

Expand All @@ -143,15 +163,100 @@ class StatefulProcessorHandleImpl(
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ValueState[T] = {
verifyStateVarOperations("get_value_state")
stateVariables.add(new StateVariableInfo(stateName, ValueState, true))
validateTTLConfig(ttlConfig, stateName)

assert(batchTimestampMs.isDefined)
val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
val resultState = new ValueStateImplWithTTL[T](store, stateName,
keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
incrementMetric("numValueStateWithTTLVars")
ttlStates.add(valueStateWithTTL)
valueStateWithTTL
ttlStates.add(resultState)
stateVariables.add(new StateVariableInfo(stateName, ValueState, true))
val colFamilySchema = resultState.columnFamilySchema
validateStateVariableCreation(colFamilySchema)
columnFamilySchemas.add(colFamilySchema)
resultState
}

/**
 * Creates the list state variable `stateName` and records it on this handle.
 *
 * @param stateName - name of the state variable
 * @param valEncoder - SQL encoder for state variable
 * @tparam T - type of state variable
 * @return - the newly constructed list state
 */
override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
  verifyStateVarOperations("get_list_state")
  incrementMetric("numListStateVars")
  val listState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder)
  stateVariables.add(new StateVariableInfo(stateName, ListState, false))

  // Check the schema against any existing column family of the same name, then record it.
  val schema = listState.columnFamilySchema
  validateStateVariableCreation(schema)
  columnFamilySchemas.add(schema)
  listState
}

/**
 * Function to create new or return existing list state variable of given type
 * with ttl. State values will not be returned past ttlDuration, and will be eventually removed
 * from the state store. Any values in listState which have expired after ttlDuration will not
 * be returned on get() and will be eventually removed from the state.
 *
 * The user must ensure to call this function only within the `init()` method of the
 * StatefulProcessor.
 *
 * @param stateName - name of the state variable
 * @param valEncoder - SQL encoder for state variable
 * @param ttlConfig - the ttl configuration (time to live duration etc.)
 * @tparam T - type of state variable
 * @return - instance of ListState of type T that can be used to store state persistently
 */
override def getListState[T](
    stateName: String,
    valEncoder: Encoder[T],
    ttlConfig: TTLConfig): ListState[T] = {
  verifyStateVarOperations("get_list_state")
  validateTTLConfig(ttlConfig, stateName)

  assert(batchTimestampMs.isDefined)
  val ttlListState = new ListStateImplWithTTL[T](
    store, stateName, keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
  incrementMetric("numListStateWithTTLVars")
  ttlStates.add(ttlListState)
  stateVariables.add(new StateVariableInfo(stateName, ListState, true))

  // Check the schema against any existing column family of the same name, then record it.
  val schema = ttlListState.columnFamilySchema
  validateStateVariableCreation(schema)
  columnFamilySchemas.add(schema)
  ttlListState
}

/**
 * Creates the map state variable `stateName` and records it on this handle.
 *
 * @param stateName - name of the state variable
 * @param userKeyEnc - SQL encoder for the user key of the map
 * @param valEncoder - SQL encoder for the map values
 * @tparam K - type of the user key
 * @tparam V - type of the map value
 * @return - the newly constructed map state
 */
override def getMapState[K, V](
    stateName: String,
    userKeyEnc: Encoder[K],
    valEncoder: Encoder[V]): MapState[K, V] = {
  verifyStateVarOperations("get_map_state")
  incrementMetric("numMapStateVars")
  val mapState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder)
  stateVariables.add(new StateVariableInfo(stateName, MapState, false))

  // Check the schema against any existing column family of the same name, then record it.
  val schema = mapState.columnFamilySchema
  validateStateVariableCreation(schema)
  columnFamilySchemas.add(schema)
  mapState
}

/**
 * Creates the map state variable `stateName` with ttl and records it on this handle.
 *
 * The ttl configuration is validated first, and a batch timestamp must be defined.
 *
 * @param stateName - name of the state variable
 * @param userKeyEnc - SQL encoder for the user key of the map
 * @param valEncoder - SQL encoder for the map values
 * @param ttlConfig - the ttl configuration (time to live duration etc.)
 * @tparam K - type of the user key
 * @tparam V - type of the map value
 * @return - the newly constructed map state with ttl
 */
override def getMapState[K, V](
    stateName: String,
    userKeyEnc: Encoder[K],
    valEncoder: Encoder[V],
    ttlConfig: TTLConfig): MapState[K, V] = {
  verifyStateVarOperations("get_map_state")
  validateTTLConfig(ttlConfig, stateName)

  assert(batchTimestampMs.isDefined)
  val ttlMapState = new MapStateImplWithTTL[K, V](
    store, stateName, keyEncoder, userKeyEnc, valEncoder, ttlConfig, batchTimestampMs.get)
  incrementMetric("numMapStateWithTTLVars")
  ttlStates.add(ttlMapState)
  stateVariables.add(new StateVariableInfo(stateName, MapState, true))

  // Check the schema against any existing column family of the same name, then record it.
  val schema = ttlMapState.columnFamilySchema
  validateStateVariableCreation(schema)
  columnFamilySchemas.add(schema)
  ttlMapState
}

/** Returns the query info held by this handle in `currQueryInfo`. */
override def getQueryInfo(): QueryInfo = {
  currQueryInfo
}
Expand All @@ -163,6 +268,7 @@ class StatefulProcessorHandleImpl(
throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(operationType,
currState.toString)
}

}

private def verifyTimerOperations(operationType: String): Unit = {
Expand Down Expand Up @@ -243,76 +349,6 @@ class StatefulProcessorHandleImpl(
}
}

/**
 * Registers the list state variable `stateName` on this handle and constructs it.
 *
 * @param stateName - name of the state variable
 * @param valEncoder - SQL encoder for state variable
 * @tparam T - type of state variable
 * @return - the newly constructed list state
 */
override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = {
  verifyStateVarOperations("get_list_state")
  stateVariables.add(new StateVariableInfo(stateName, ListState, false))
  incrementMetric("numListStateVars")
  // The freshly built state is the result of the method.
  new ListStateImpl[T](store, stateName, keyEncoder, valEncoder)
}

/**
 * Function to create new or return existing list state variable of given type
 * with ttl. State values will not be returned past ttlDuration, and will be eventually removed
 * from the state store. Any values in listState which have expired after ttlDuration will not
 * be returned on get() and will be eventually removed from the state.
 *
 * The user must ensure to call this function only within the `init()` method of the
 * StatefulProcessor.
 *
 * @param stateName - name of the state variable
 * @param valEncoder - SQL encoder for state variable
 * @param ttlConfig - the ttl configuration (time to live duration etc.)
 * @tparam T - type of state variable
 * @return - instance of ListState of type T that can be used to store state persistently
 */
override def getListState[T](
    stateName: String,
    valEncoder: Encoder[T],
    ttlConfig: TTLConfig): ListState[T] = {
  verifyStateVarOperations("get_list_state")
  stateVariables.add(new StateVariableInfo(stateName, ListState, true))
  validateTTLConfig(ttlConfig, stateName)

  assert(batchTimestampMs.isDefined)
  val ttlListState = new ListStateImplWithTTL[T](
    store, stateName, keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get)
  incrementMetric("numListStateWithTTLVars")
  ttlStates.add(ttlListState)
  ttlListState
}

/**
 * Registers the map state variable `stateName` on this handle and constructs it.
 *
 * @param stateName - name of the state variable
 * @param userKeyEnc - SQL encoder for the user key of the map
 * @param valEncoder - SQL encoder for the map values
 * @tparam K - type of the user key
 * @tparam V - type of the map value
 * @return - the newly constructed map state
 */
override def getMapState[K, V](
    stateName: String,
    userKeyEnc: Encoder[K],
    valEncoder: Encoder[V]): MapState[K, V] = {
  verifyStateVarOperations("get_map_state")
  stateVariables.add(new StateVariableInfo(stateName, MapState, false))
  incrementMetric("numMapStateVars")
  // The freshly built state is the result of the method.
  new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder)
}

/**
 * Registers the map state variable `stateName` with ttl on this handle and constructs it.
 *
 * The ttl configuration is validated before construction, and a batch timestamp must be
 * defined.
 *
 * @param stateName - name of the state variable
 * @param userKeyEnc - SQL encoder for the user key of the map
 * @param valEncoder - SQL encoder for the map values
 * @param ttlConfig - the ttl configuration (time to live duration etc.)
 * @tparam K - type of the user key
 * @tparam V - type of the map value
 * @return - the newly constructed map state with ttl
 */
override def getMapState[K, V](
    stateName: String,
    userKeyEnc: Encoder[K],
    valEncoder: Encoder[V],
    ttlConfig: TTLConfig): MapState[K, V] = {
  verifyStateVarOperations("get_map_state")
  stateVariables.add(new StateVariableInfo(stateName, MapState, true))
  validateTTLConfig(ttlConfig, stateName)

  assert(batchTimestampMs.isDefined)
  val ttlMapState = new MapStateImplWithTTL[K, V](
    store, stateName, keyEncoder, userKeyEnc, valEncoder, ttlConfig, batchTimestampMs.get)
  incrementMetric("numMapStateWithTTLVars")
  ttlStates.add(ttlMapState)
  ttlMapState
}

private def validateTTLConfig(ttlConfig: TTLConfig, stateName: String): Unit = {
val ttlDuration = ttlConfig.ttlDuration
if (timeMode != TimeMode.ProcessingTime()) {
Expand Down
Loading
Loading