add filter push down #1642

Merged

Changes from 1 commit
@@ -10,16 +10,20 @@ import java.util
 import org.opensearch.flint.core.{FlintClientBuilder, FlintOptions}

 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
+import org.apache.spark.sql.flint.storage.FlintQueryCompiler
 import org.apache.spark.sql.types.StructType

 case class FlintPartitionReaderFactory(
     tableName: String,
     schema: StructType,
-    properties: util.Map[String, String])
+    properties: util.Map[String, String],
+    pushedPredicates: Array[Predicate])
     extends PartitionReaderFactory {
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+    val query = FlintQueryCompiler(schema).compile(pushedPredicates)
     val flintClient = FlintClientBuilder.build(new FlintOptions(properties))
-    new FlintPartitionReader(flintClient.createReader(tableName, ""), schema)
+    new FlintPartitionReader(flintClient.createReader(tableName, query), schema)
   }
 }
@@ -7,10 +7,15 @@ package org.apache.spark.sql.flint

 import java.util

+import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory, Scan}
 import org.apache.spark.sql.types.StructType

-case class FlintScan(tableName: String, schema: StructType, properties: util.Map[String, String])
+case class FlintScan(
+    tableName: String,
+    schema: StructType,
+    properties: util.Map[String, String],
+    pushedPredicates: Array[Predicate])
     extends Scan
     with Batch {

@@ -21,10 +26,19 @@ case class FlintScan(tableName: String, schema: StructType, properties: util.Map
   }

   override def createReaderFactory(): PartitionReaderFactory = {
-    FlintPartitionReaderFactory(tableName, schema, properties)
+    FlintPartitionReaderFactory(tableName, schema, properties, pushedPredicates)
   }

   override def toBatch: Batch = this
+
+  /**
+   * Print pushedPredicates when explain(mode="extended"). Learn from SPARK JDBCScan.
+   */
+  override def description(): String = {
+    super.description() + ", PushedPredicates: " + seqToString(pushedPredicates)
+  }
+
+  private def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
 }

 // todo. add partition support.
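The description() override makes the pushed predicates visible in the query plan, so the effect of the pushdown can be checked with explain. Below is a minimal sketch of how that might look from the DataFrame side; the "flint" format name, the host/port options, the index name "my_index", and the exact predicate rendering are illustrative assumptions, not something this diff pins down.

import org.apache.spark.sql.SparkSession

// Hedged sketch: format name, options, and index name are assumptions; only the
// PushedPredicates text in the scan description comes from this change.
val spark = SparkSession.builder().master("local[1]").getOrCreate()

val df = spark.read
  .format("flint")
  .option("host", "localhost")
  .option("port", "9200")
  .schema("aInt INT, aString STRING")
  .load("my_index")

// Both predicates compile to OpenSearch query DSL, so they end up in the scan node and are
// printed by FlintScan.description(), e.g. "... PushedPredicates: [aInt > 1, aString = 's']".
df.filter("aInt > 1 AND aString = 's'").explain("extended")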
@@ -9,7 +9,9 @@ import java.util

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownV2Filters}
+import org.apache.spark.sql.flint.storage.FlintQueryCompiler
 import org.apache.spark.sql.types.StructType

 case class FlintScanBuilder(
@@ -18,9 +20,21 @@ case class FlintScanBuilder(
     schema: StructType,
     properties: util.Map[String, String])
     extends ScanBuilder
+    with SupportsPushDownV2Filters
     with Logging {

+  private var pushedPredicate = Array.empty[Predicate]
+
   override def build(): Scan = {
-    FlintScan(tableName, schema, properties)
+    FlintScan(tableName, schema, properties, pushedPredicate)
   }
+
+  override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
+    val (pushed, unSupported) =
+      predicates.partition(FlintQueryCompiler(schema).compile(_).nonEmpty)
+    pushedPredicate = pushed
+    unSupported
+  }
+
+  override def pushedPredicates(): Array[Predicate] = pushedPredicate
 }
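For context on the SupportsPushDownV2Filters contract: pushPredicates receives the candidate predicates, keeps the ones the compiler can translate, and returns the rest so Spark still evaluates them after the scan; pushedPredicates only reports what was kept. A small sketch of that split follows; it assumes it lives in a package under org.apache.spark.sql (as the new test suite does) so the private[sql] Filter.toV2 helper is accessible, and the object and column names are made up for illustration.

package org.apache.spark.sql.flint.storage

import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.sources.{EqualNullSafe, EqualTo}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object PushdownSplitSketch extends App {
  val schema = StructType(
    Seq(StructField("aInt", IntegerType), StructField("aString", StringType)))

  // "aInt = 1" compiles to a term query and is pushed; "aString <=> 's'" has no
  // mapping in FlintQueryCompiler.visitPredicate, so it stays with Spark.
  val predicates: Array[Predicate] =
    Array(EqualTo("aInt", 1).toV2, EqualNullSafe("aString", "s").toV2)

  val (pushed, unsupported) =
    predicates.partition(FlintQueryCompiler(schema).compile(_).nonEmpty)

  println(s"pushed to OpenSearch: ${pushed.mkString(", ")}")
  println(s"evaluated by Spark after the scan: ${unsupported.mkString(", ")}")
}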
@@ -0,0 +1,110 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.flint.storage

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{And, Predicate}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

/**
* Todo. find the right package.
*/
case class FlintQueryCompiler(schema: StructType) {

/**
   * Using AND to concat predicates. Todo. If spark.sql.ansi.enabled = true, more expressions
   * defined in V2ExpressionBuilder could be pushed down.
*/
def compile(predicates: Array[Predicate]): String = {
if (predicates.isEmpty) {
return ""
}
compile(predicates.reduce(new And(_, _)))
}

/**
* Compile Expression to Flint query string.
*
* @param expr
* Expression.
* @return
   * empty if not supported.
*/
def compile(expr: Expression, quoteString: Boolean = true): String = {
expr match {
case LiteralValue(value, dataType) =>
if (quoteString && dataType == StringType) {
s""""${Literal(value, dataType).toString()}""""
} else {
Literal(value, dataType).toString()
}
case p: Predicate => visitPredicate(p)
case f: FieldReference => f.toString()
case _ => ""
}
}

/**
* Predicate is defined in SPARK filters.scala. Todo.
* 1. currently, we map spark contains to OpenSearch match query. Can we leverage more full
* text queries for text field. 2. configuration of expensive query.
*/
def visitPredicate(p: Predicate): String = {
val name = p.name()
name match {
case "IS_NULL" =>
s"""{"bool":{"must_not":{"exists":{"field":"${compile(p.children()(0))}"}}}}"""
case "IS_NOT_NULL" =>
s"""{"exists":{"field":"${compile(p.children()(0))}"}}"""
case "AND" =>
s"""{"bool":{"filter":[${compile(p.children()(0))},${compile(p.children()(1))}]}}"""
case "OR" =>
s"""{"bool":{"should":[{"bool":{"filter":${compile(
p.children()(0))}}},{"bool":{"filter":${compile(p.children()(1))}}}]}}"""
case "NOT" =>
s"""{"bool":{"must_not":${compile(p.children()(0))}}}"""
case "=" =>
s"""{"term":{"${compile(p.children()(0))}":{"value":${compile(p.children()(1))}}}}"""
case ">" =>
s"""{"range":{"${compile(p.children()(0))}":{"gt":${compile(p.children()(1))}}}}"""
case ">=" =>
s"""{"range":{"${compile(p.children()(0))}":{"gte":${compile(p.children()(1))}}}}"""
case "<" =>
s"""{"range":{"${compile(p.children()(0))}":{"lt":${compile(p.children()(1))}}}}"""
case "<=" =>
s"""{"range":{"${compile(p.children()(0))}":{"lte":${compile(p.children()(1))}}}}"""
case "IN" =>
val values = p.children().tail.map(expr => compile(expr)).mkString("[", ",", "]")
s"""{"terms":{"${compile(p.children()(0))}":$values}}"""
case "STARTS_WITH" =>
s"""{"prefix":{"${compile(p.children()(0))}":{"value":${compile(p.children()(1))}}}}"""
case "CONTAINS" =>
val fieldName = compile(p.children()(0))
if (isTextField(fieldName)) {
s"""{"match":{"$fieldName":{"query":${compile(p.children()(1))}}}}"""
} else {
s"""{"wildcard":{"$fieldName":{"value":"*${compile(p.children()(1), false)}*"}}}"""
}
case "ENDS_WITH" =>
s"""{"wildcard":{"${compile(p.children()(0))}":{"value":"*${compile(
p.children()(1),
false)}"}}}"""
case _ => ""
}
}

/**
* return true if the field is Flint Text field.
*/
protected def isTextField(attribute: String): Boolean = {
schema.apply(attribute) match {
case StructField(_, StringType, _, metadata) =>
metadata.contains("osType") && metadata.getString("osType") == "text"
case _ => false
}
}
}
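One behaviour worth calling out from the compiler above: CONTAINS only becomes a full-text match query when the column's schema metadata marks it as an OpenSearch text field via osType=text; plain string columns fall back to a wildcard query with the unquoted literal. The sketch below shows that convention; the object and field names are illustrative, and like the test suite it assumes a package under org.apache.spark.sql so Filter.toV2 is accessible.

package org.apache.spark.sql.flint.storage

import org.apache.spark.sql.sources.StringContains
import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType}

object TextFieldSketch extends App {
  // "message" carries the osType=text metadata that isTextField checks for;
  // "tag" is a plain string column.
  val schema = StructType(
    Seq(
      StructField(
        "message",
        StringType,
        nullable = true,
        new MetadataBuilder().putString("osType", "text").build()),
      StructField("tag", StringType, nullable = true)))

  // Text field -> match query.
  println(FlintQueryCompiler(schema).compile(StringContains("message", "error").toV2))
  // {"match":{"message":{"query":"error"}}}

  // Non-text field -> wildcard query with the raw literal.
  println(FlintQueryCompiler(schema).compile(StringContains("tag", "error").toV2))
  // {"wildcard":{"tag":{"value":"*error*"}}}
}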
@@ -0,0 +1,151 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.spark.sql.flint.storage

import org.apache.spark.FlintSuite
import org.apache.spark.sql.connector.expressions.{FieldReference, GeneralScalarExpression}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

class FlintQueryCompilerSuit extends FlintSuite {
Collaborator: np: Suit -> Suite

Author: done.

test("compile a list of expressions should successfully") {
val query =
FlintQueryCompiler(schema()).compile(
Array(EqualTo("aInt", 1).toV2, EqualTo("aString", "s").toV2))
assertResult(
"""{"bool":{"filter":[{"term":{"aInt":{"value":1}}},{"term":{"aString":{"value":"s"}}}]}}""")(
query)
}

test("compile a list of expressions contain one expression should successfully") {
val query =
FlintQueryCompiler(schema()).compile(Array(EqualTo("aInt", 1).toV2))
assertResult("""{"term":{"aInt":{"value":1}}}""")(query)
}

test("compile a empty list of expressions should return empty") {
val query =
FlintQueryCompiler(schema()).compile(Array.empty[Predicate])
assert(query.isEmpty)
}

test("compile unsupported expression abs(aInt) should return empty string") {
val query = FlintQueryCompiler(schema()).compile(
// SPARK V2ExpressionBuilder define the expression.
new GeneralScalarExpression("ABS", Array(FieldReference.apply("aInt"))))
assert(query.isEmpty)
}

test("compile and(aInt=1, aString=s) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(And(EqualTo("aInt", 1), EqualTo("aString", "s")).toV2)
assertResult(
"""{"bool":{"filter":[{"term":{"aInt":{"value":1}}},{"term":{"aString":{"value":"s"}}}]}}""")(
query)
}

test("compile or(aInt=1, aString=s) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(Or(EqualTo("aInt", 1), EqualTo("aString", "s")).toV2)
// scalastyle:off
assertResult(
"""{"bool":{"should":[{"bool":{"filter":{"term":{"aInt":{"value":1}}}}},{"bool":{"filter":{"term":{"aString":{"value":"s"}}}}}]}}""")(
query)
// scalastyle:on
}

test("compile and(aInt>1, aString>s) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(
And(GreaterThan("aInt", 1), GreaterThan("aString", "s")).toV2)
assertResult(
"""{"bool":{"filter":[{"range":{"aInt":{"gt":1}}},{"range":{"aString":{"gt":"s"}}}]}}""")(
query)
}

test("compile and(aInt>=1, aString>=s) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(
And(GreaterThanOrEqual("aInt", 1), GreaterThanOrEqual("aString", "s")).toV2)
assertResult(
"""{"bool":{"filter":[{"range":{"aInt":{"gte":1}}},{"range":{"aString":{"gte":"s"}}}]}}""")(
query)
}

test("compile and(aInt<1, aString<s) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(
And(LessThan("aInt", 1), LessThan("aString", "s")).toV2)
assertResult(
"""{"bool":{"filter":[{"range":{"aInt":{"lt":1}}},{"range":{"aString":{"lt":"s"}}}]}}""")(
query)
}

test("compile and(aInt<=1, aString<=s) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(
And(LessThanOrEqual("aInt", 1), LessThanOrEqual("aString", "s")).toV2)
assertResult(
"""{"bool":{"filter":[{"range":{"aInt":{"lte":1}}},{"range":{"aString":{"lte":"s"}}}]}}""")(
query)
}

test("compile aInt IN (1, 2, 3) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(In("aInt", Array(1, 2, 3)).toV2)
assertResult("""{"terms":{"aInt":[1,2,3]}}""")(query)
}

test("compile STARTS_WITH(aString, s) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(StringStartsWith("aString", "s").toV2)
assertResult("""{"prefix":{"aString":{"value":"s"}}}""")(query)
}

test("compile CONTAINS(aString, s) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(StringContains("aString", "s").toV2)
assertResult("""{"wildcard":{"aString":{"value":"*s*"}}}""")(query)
}

test("compile CONTAINS(aText, s) should use match query") {
val query =
FlintQueryCompiler(schema()).compile(StringContains("aText", "s").toV2)
assertResult("""{"match":{"aText":{"query":"s"}}}""")(query)
}

test("compile ENDS_WITH(aString, s) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(StringEndsWith("aString", "s").toV2)
assertResult("""{"wildcard":{"aString":{"value":"*s"}}}""")(query)
}

test("compile IS_NULL(aString) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(IsNull("aString").toV2)
assertResult("""{"bool":{"must_not":{"exists":{"field":"aString"}}}}""")(query)
}

test("compile IS_NOT_NULL(aString) should successfully") {
val query =
FlintQueryCompiler(schema()).compile(IsNotNull("aString").toV2)
assertResult("""{"exists":{"field":"aString"}}""")(query)
}

protected def schema(): StructType = {
StructType(
Seq(
StructField("aString", StringType, nullable = true),
StructField("aInt", IntegerType, nullable = true),
StructField(
"aText",
StringType,
nullable = true,
new MetadataBuilder().putString("osType", "text").build())))
}
}