Commit

Extract source table names from mv query (opensearch-project#854)
* add sourceTables to MV index metadata properties

Signed-off-by: Sean Kao <seankao@amazon.com>

* parse source tables from mv query

Signed-off-by: Sean Kao <seankao@amazon.com>

* test cases for parsing source tables from mv query

Signed-off-by: Sean Kao <seankao@amazon.com>

* use constant for metadata cache version

Signed-off-by: Sean Kao <seankao@amazon.com>

* write source tables to metadata cache

Signed-off-by: Sean Kao <seankao@amazon.com>

* address comment

Signed-off-by: Sean Kao <seankao@amazon.com>

* generate source tables for old mv without new prop

Signed-off-by: Sean Kao <seankao@amazon.com>

* syntax fix

Signed-off-by: Sean Kao <seankao@amazon.com>

---------

Signed-off-by: Sean Kao <seankao@amazon.com>
seankao-az authored and 14yapkc1 committed Dec 11, 2024
1 parent 3cdc2d9 commit ea1fe54
Showing 11 changed files with 369 additions and 97 deletions.
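
To illustrate the change described in the commit message, here is a minimal sketch of the new helper on the FlintSparkMaterializedView companion object. The Spark session setup and the query below are assumptions made up for illustration; extractSourceTableNames, the "sourceTables" metadata property, and the new FlintSparkIndexFactory.create(spark, metadata) signature come from this commit.

import org.apache.spark.sql.SparkSession
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView

// Hypothetical session and MV query, used only for illustration.
val spark: SparkSession = SparkSession.builder().getOrCreate()
val mvQuery =
  "SELECT name, COUNT(*) AS cnt FROM spark_catalog.default.test_table GROUP BY name"

// extractSourceTableNames (added in this commit) parses the query plan and
// returns the qualified name of every relation the query reads from.
val sourceTables: Array[String] =
  FlintSparkMaterializedView.extractSourceTableNames(spark, mvQuery)
// Expected to contain "spark_catalog.default.test_table". The MV builder now
// calls this when the query is set, stores the result in the index metadata
// under "sourceTables", and FlintSparkIndexFactory.create(spark, metadata)
// reads it back, falling back to re-parsing the query for older MV indexes
// created before this property existed.
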
@@ -183,7 +183,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w
attachLatestLogEntry(indexName, metadata)
}
.toList
.flatMap(FlintSparkIndexFactory.create)
.flatMap(metadata => FlintSparkIndexFactory.create(spark, metadata))
} else {
Seq.empty
}
@@ -202,7 +202,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w
if (flintClient.exists(indexName)) {
val metadata = flintIndexMetadataService.getIndexMetadata(indexName)
val metadataWithEntry = attachLatestLogEntry(indexName, metadata)
FlintSparkIndexFactory.create(metadataWithEntry)
FlintSparkIndexFactory.create(spark, metadataWithEntry)
} else {
Option.empty
}
@@ -327,7 +327,7 @@ class FlintSpark(val spark: SparkSession) extends FlintSparkTransactionSupport w
val index = describeIndex(indexName)

if (index.exists(_.options.autoRefresh())) {
val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(index.get).get
val updatedIndex = FlintSparkIndexFactory.createWithDefaultOptions(spark, index.get).get
FlintSparkIndexRefresh
.create(updatedIndex.name(), updatedIndex)
.validate(spark)
@@ -92,7 +92,7 @@ abstract class FlintSparkIndexBuilder(flint: FlintSpark) {
val updatedMetadata = index
.metadata()
.copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava)
validateIndex(FlintSparkIndexFactory.create(updatedMetadata).get)
validateIndex(FlintSparkIndexFactory.create(flint.spark, updatedMetadata).get)
}

/**
@@ -25,6 +25,7 @@ import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession

/**
* Flint Spark index factory that encapsulates specific Flint index instance creation. This is for
@@ -35,14 +36,16 @@ object FlintSparkIndexFactory extends Logging {
/**
* Creates Flint index from generic Flint metadata.
*
* @param spark
* Spark session
* @param metadata
* Flint metadata
* @return
* Flint index instance, or None if any error during creation
*/
def create(metadata: FlintMetadata): Option[FlintSparkIndex] = {
def create(spark: SparkSession, metadata: FlintMetadata): Option[FlintSparkIndex] = {
try {
Some(doCreate(metadata))
Some(doCreate(spark, metadata))
} catch {
case e: Exception =>
logWarning(s"Failed to create Flint index from metadata $metadata", e)
@@ -53,24 +56,26 @@ object FlintSparkIndexFactory extends Logging {
/**
* Creates Flint index with default options.
*
* @param spark
* Spark session
* @param index
* Flint index
* @param metadata
* Flint metadata
* @return
* Flint index with default options
*/
def createWithDefaultOptions(index: FlintSparkIndex): Option[FlintSparkIndex] = {
def createWithDefaultOptions(
spark: SparkSession,
index: FlintSparkIndex): Option[FlintSparkIndex] = {
val originalOptions = index.options
val updatedOptions =
FlintSparkIndexOptions.updateOptionsWithDefaults(index.name(), originalOptions)
val updatedMetadata = index
.metadata()
.copy(options = updatedOptions.options.mapValues(_.asInstanceOf[AnyRef]).asJava)
this.create(updatedMetadata)
this.create(spark, updatedMetadata)
}

private def doCreate(metadata: FlintMetadata): FlintSparkIndex = {
private def doCreate(spark: SparkSession, metadata: FlintMetadata): FlintSparkIndex = {
val indexOptions = FlintSparkIndexOptions(
metadata.options.asScala.mapValues(_.asInstanceOf[String]).toMap)
val latestLogEntry = metadata.latestLogEntry
@@ -118,6 +123,7 @@ object FlintSparkIndexFactory extends Logging {
FlintSparkMaterializedView(
metadata.name,
metadata.source,
getMvSourceTables(spark, metadata),
metadata.indexedColumns.map { colInfo =>
getString(colInfo, "columnName") -> getString(colInfo, "columnType")
}.toMap,
@@ -134,6 +140,15 @@ object FlintSparkIndexFactory extends Logging {
.toMap
}

private def getMvSourceTables(spark: SparkSession, metadata: FlintMetadata): Array[String] = {
val sourceTables = getArrayString(metadata.properties, "sourceTables")
if (sourceTables.isEmpty) {
FlintSparkMaterializedView.extractSourceTableNames(spark, metadata.source)
} else {
sourceTables
}
}

private def getString(map: java.util.Map[String, AnyRef], key: String): String = {
map.get(key).asInstanceOf[String]
}
@@ -146,4 +161,12 @@ object FlintSparkIndexFactory extends Logging {
Some(value.asInstanceOf[String])
}
}

private def getArrayString(map: java.util.Map[String, AnyRef], key: String): Array[String] = {
map.get(key) match {
case list: java.util.ArrayList[_] =>
list.toArray.map(_.toString)
case _ => Array.empty[String]
}
}
}
@@ -11,9 +11,8 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.flint.{loadTable, parseTableName, qualifyTableName}
import org.apache.spark.sql.flint.{loadTable, parseTableName}

/**
* Flint Spark validation helper.
@@ -31,16 +30,10 @@ trait FlintSparkValidationHelper extends Logging {
* true if all non Hive, otherwise false
*/
def isTableProviderSupported(spark: SparkSession, index: FlintSparkIndex): Boolean = {
// Extract source table name (possibly more than one for MV query)
val tableNames = index match {
case skipping: FlintSparkSkippingIndex => Seq(skipping.tableName)
case covering: FlintSparkCoveringIndex => Seq(covering.tableName)
case mv: FlintSparkMaterializedView =>
spark.sessionState.sqlParser
.parsePlan(mv.query)
.collect { case relation: UnresolvedRelation =>
qualifyTableName(spark, relation.tableName)
}
case mv: FlintSparkMaterializedView => mv.sourceTables.toSeq
}

// Validate if any source table is not supported (currently Hive only)
@@ -10,6 +10,7 @@ import scala.collection.JavaConverters.mapAsScalaMapConverter
import org.opensearch.flint.common.metadata.FlintMetadata
import org.opensearch.flint.common.metadata.log.FlintMetadataLogEntry
import org.opensearch.flint.spark.FlintSparkIndexOptions
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.opensearch.flint.spark.scheduler.util.IntervalSchedulerParser

/**
@@ -46,9 +47,7 @@ case class FlintMetadataCache(

object FlintMetadataCache {

// TODO: constant for version
val mockTableName =
"dataSourceName.default.logGroups(logGroupIdentifier:['arn:aws:logs:us-east-1:123456:test-llt-xa', 'arn:aws:logs:us-east-1:123456:sample-lg-1'])"
val metadataCacheVersion = "1.0"

def apply(metadata: FlintMetadata): FlintMetadataCache = {
val indexOptions = FlintSparkIndexOptions(
@@ -61,14 +60,22 @@ object FlintMetadataCache {
} else {
None
}
val sourceTables = metadata.kind match {
case MV_INDEX_TYPE =>
metadata.properties.get("sourceTables") match {
case list: java.util.ArrayList[_] =>
list.toArray.map(_.toString)
case _ => Array.empty[String]
}
case _ => Array(metadata.source)
}
val lastRefreshTime: Option[Long] = metadata.latestLogEntry.flatMap { entry =>
entry.lastRefreshCompleteTime match {
case FlintMetadataLogEntry.EMPTY_TIMESTAMP => None
case timestamp => Some(timestamp)
}
}

// TODO: get source tables from metadata
FlintMetadataCache("1.0", refreshInterval, Array(mockTableName), lastRefreshTime)
FlintMetadataCache(metadataCacheVersion, refreshInterval, sourceTables, lastRefreshTime)
}
}
@@ -34,6 +34,8 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* MV name
* @param query
* source query that generates MV data
* @param sourceTables
* source table names
* @param outputSchema
* output schema
* @param options
@@ -44,6 +46,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class FlintSparkMaterializedView(
mvName: String,
query: String,
sourceTables: Array[String],
outputSchema: Map[String, String],
override val options: FlintSparkIndexOptions = empty,
override val latestLogEntry: Option[FlintMetadataLogEntry] = None)
@@ -64,6 +67,7 @@ case class FlintSparkMaterializedView(
metadataBuilder(this)
.name(mvName)
.source(query)
.addProperty("sourceTables", sourceTables)
.indexedColumns(indexColumnMaps)
.schema(schema)
.build()
@@ -165,10 +169,30 @@ object FlintSparkMaterializedView {
flintIndexNamePrefix(mvName)
}

/**
* Extract source table names (possibly more than one) from the query.
*
* @param spark
* Spark session
* @param query
* source query that generates MV data
* @return
* source table names
*/
def extractSourceTableNames(spark: SparkSession, query: String): Array[String] = {
spark.sessionState.sqlParser
.parsePlan(query)
.collect { case relation: UnresolvedRelation =>
qualifyTableName(spark, relation.tableName)
}
.toArray
}

/** Builder class for MV build */
class Builder(flint: FlintSpark) extends FlintSparkIndexBuilder(flint) {
private var mvName: String = ""
private var query: String = ""
private var sourceTables: Array[String] = Array.empty[String]

/**
* Set MV name.
@@ -193,6 +217,7 @@ object FlintSparkMaterializedView {
*/
def query(query: String): Builder = {
this.query = query
this.sourceTables = extractSourceTableNames(flint.spark, query)
this
}

@@ -221,7 +246,7 @@ object FlintSparkMaterializedView {
field.name -> field.dataType.simpleString
}
.toMap
FlintSparkMaterializedView(mvName, query, outputSchema, indexOptions)
FlintSparkMaterializedView(mvName, query, sourceTables, outputSchema, indexOptions)
}
}
}
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import org.opensearch.flint.core.storage.FlintOpenSearchIndexMetadataService
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView
import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
import org.scalatest.matchers.should.Matchers._

import org.apache.spark.FlintSuite

class FlintSparkIndexFactorySuite extends FlintSuite {

test("create mv should generate source tables if missing in metadata") {
val testTable = "spark_catalog.default.mv_build_test"
val testMvName = "spark_catalog.default.mv"
val testQuery = s"SELECT * FROM $testTable"

val content =
s""" {
| "_meta": {
| "kind": "$MV_INDEX_TYPE",
| "indexedColumns": [
| {
| "columnType": "int",
| "columnName": "age"
| }
| ],
| "name": "$testMvName",
| "source": "$testQuery"
| },
| "properties": {
| "age": {
| "type": "integer"
| }
| }
| }
|""".stripMargin

val metadata = FlintOpenSearchIndexMetadataService.deserialize(content)
val index = FlintSparkIndexFactory.create(spark, metadata)
index shouldBe defined
index.get
.asInstanceOf[FlintSparkMaterializedView]
.sourceTables should contain theSameElementsAs Array(testTable)
}
}