Skip to content

Commit 4caad5d

Browse files
committed
fix in expr
1 parent 919c54d commit 4caad5d

File tree

2 files changed: +93 additions, −77 deletions

spark/src/main/scala/org/apache/comet/parquet/ParquetFilters.scala

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ class ParquetFilters(
896896
def nameUnaryExpr(name: String)(
897897
f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder)
898898
: Option[ExprOuterClass.Expr] = {
899-
createNameExpr(name, dataSchema).map { childExpr =>
899+
createNameExpr(name, dataSchema).map { case (_, childExpr) =>
900900
createUnaryExpr(childExpr, f)
901901
}
902902
}
@@ -906,10 +906,8 @@ class ParquetFilters(
906906
ExprOuterClass.Expr.Builder,
907907
ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
908908
: Option[ExprOuterClass.Expr] = {
909-
(createNameExpr(name, dataSchema), createValueExpr(value)) match {
910-
case (Some(nameExpr), Some(valueExpr)) =>
911-
Some(createBinaryExpr(nameExpr, valueExpr, f))
912-
case _ => None
909+
createNameExpr(name, dataSchema).flatMap { case (dataType, childExpr) =>
910+
createValueExpr(value, dataType).map(createBinaryExpr(childExpr, _, f))
913911
}
914912
}
915913

@@ -996,20 +994,21 @@ class ParquetFilters(
996994
case sources.In(name, values)
997995
if pushDownInFilterThreshold > 0 && values.nonEmpty &&
998996
canMakeFilterOn(name, values.head) =>
999-
val nameExpr = createNameExpr(name, dataSchema)
1000-
val valueExprs = values.flatMap(createValueExpr)
1001-
if (nameExpr.isEmpty || valueExprs.length != values.length) {
1002-
None
1003-
} else {
1004-
val builder = ExprOuterClass.In.newBuilder()
1005-
builder.setInValue(nameExpr.get)
1006-
builder.addAllLists(valueExprs.toSeq.asJava)
1007-
builder.setNegated(false)
1008-
Some(
1009-
ExprOuterClass.Expr
1010-
.newBuilder()
1011-
.setIn(builder)
1012-
.build())
997+
createNameExpr(name, dataSchema).flatMap { case (dataType, nameExpr) =>
998+
val valueExprs = values.flatMap(createValueExpr(_, dataType))
999+
if (valueExprs.length != values.length) {
1000+
None
1001+
} else {
1002+
val builder = ExprOuterClass.In.newBuilder()
1003+
builder.setInValue(nameExpr)
1004+
builder.addAllLists(valueExprs.toSeq.asJava)
1005+
builder.setNegated(false)
1006+
Some(
1007+
ExprOuterClass.Expr
1008+
.newBuilder()
1009+
.setIn(builder)
1010+
.build())
1011+
}
10131012
}
10141013

10151014
case sources.StringStartsWith(name, prefix)

spark/src/main/scala/org/apache/comet/parquet/SourceFilterSerde.scala

Lines changed: 75 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,23 @@
1919

2020
package org.apache.comet.parquet
2121

22+
import java.math.{BigDecimal => JavaBigDecimal}
23+
import java.sql.{Date, Timestamp}
24+
import java.time.{Instant, LocalDate, LocalDateTime}
25+
2226
import org.apache.spark.internal.Logging
27+
import org.apache.spark.sql.catalyst.util.DateTimeUtils
2328
import org.apache.spark.sql.types._
24-
import org.apache.spark.unsafe.types.UTF8String
2529

2630
import org.apache.comet.serde.ExprOuterClass
2731
import org.apache.comet.serde.ExprOuterClass.Expr
2832
import org.apache.comet.serde.QueryPlanSerde.serializeDataType
2933

3034
object SourceFilterSerde extends Logging {
3135

32-
def createNameExpr(name: String, schema: StructType): Option[ExprOuterClass.Expr] = {
36+
def createNameExpr(
37+
name: String,
38+
schema: StructType): Option[(DataType, ExprOuterClass.Expr)] = {
3339
val filedWithIndex = schema.fields.zipWithIndex.find { case (field, _) =>
3440
field.name == name
3541
}
@@ -43,6 +49,7 @@ object SourceFilterSerde extends Logging {
4349
.setDatatype(dataType.get)
4450
.build()
4551
Some(
52+
field.dataType,
4653
ExprOuterClass.Expr
4754
.newBuilder()
4855
.setBound(boundExpr)
@@ -56,72 +63,82 @@ object SourceFilterSerde extends Logging {
5663

5764
}
5865

59-
def createValueExpr(value: Any): Option[ExprOuterClass.Expr] = {
66+
/**
67+
* create a literal value native expression for source filter value, the value is a scala value
68+
*/
69+
def createValueExpr(value: Any, dataType: DataType): Option[ExprOuterClass.Expr] = {
6070
val exprBuilder = ExprOuterClass.Literal.newBuilder()
71+
var valueIsSet = true
6172
if (value == null) {
6273
exprBuilder.setIsNull(true)
63-
Some(ExprOuterClass.Expr.newBuilder().setLiteral(exprBuilder).build())
6474
} else {
6575
exprBuilder.setIsNull(false)
66-
val dataType: Option[DataType] = value match {
67-
case v: Boolean =>
68-
exprBuilder.setBoolVal(v)
69-
Some(BooleanType)
70-
case v: Byte =>
71-
exprBuilder.setByteVal(v)
72-
Some(ByteType)
73-
case v: Short =>
74-
exprBuilder.setShortVal(v)
75-
Some(ShortType)
76-
case v: Int =>
77-
exprBuilder.setIntVal(v)
78-
Some(IntegerType)
79-
case v: Long =>
80-
exprBuilder.setLongVal(v)
81-
Some(LongType)
82-
case v: Float =>
83-
exprBuilder.setFloatVal(v)
84-
Some(FloatType)
85-
case v: Double =>
86-
exprBuilder.setDoubleVal(v)
87-
Some(DoubleType)
88-
case v: UTF8String =>
89-
exprBuilder.setStringVal(v.toString)
90-
Some(StringType)
91-
case v: Decimal =>
92-
val unscaled = v.toBigDecimal.underlying.unscaledValue
76+
// value is a scala value, not a catalyst value
77+
// refer to org.apache.spark.sql.catalyst.CatalystTypeConverters.CatalystTypeConverter#toScala
78+
dataType match {
79+
case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean])
80+
case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte])
81+
case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short])
82+
case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int])
83+
case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long])
84+
case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float])
85+
case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double])
86+
case _: StringType => exprBuilder.setStringVal(value.asInstanceOf[String])
87+
case _: TimestampType =>
88+
value match {
89+
case v: Timestamp => exprBuilder.setLongVal(DateTimeUtils.fromJavaTimestamp(v))
90+
case v: Instant => exprBuilder.setLongVal(DateTimeUtils.instantToMicros(v))
91+
case v: Long => exprBuilder.setLongVal(v)
92+
case _ =>
93+
valueIsSet = false
94+
logWarning(s"Unexpected timestamp type '${value.getClass}' for value '$value'")
95+
}
96+
case _: TimestampNTZType =>
97+
value match {
98+
case v: LocalDateTime =>
99+
exprBuilder.setLongVal(DateTimeUtils.localDateTimeToMicros(v))
100+
case v: Long => exprBuilder.setLongVal(v)
101+
case _ =>
102+
valueIsSet = false
103+
logWarning(s"Unexpected timestamp type '${value.getClass}' for value' $value'")
104+
}
105+
case _: DecimalType =>
106+
// Pass decimal literal as bytes.
107+
val unscaled = value.asInstanceOf[JavaBigDecimal].unscaledValue
93108
exprBuilder.setDecimalVal(com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray))
94-
Some(DecimalType(v.precision, v.scale))
95-
case v: Array[Byte] =>
109+
case _: BinaryType =>
96110
val byteStr =
97-
com.google.protobuf.ByteString.copyFrom(v)
111+
com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
98112
exprBuilder.setBytesVal(byteStr)
99-
Some(BinaryType)
100-
case v: java.sql.Date =>
101-
exprBuilder.setIntVal(v.getTime.toInt)
102-
Some(DateType)
103-
case v: java.sql.Timestamp =>
104-
exprBuilder.setLongVal(v.getTime)
105-
Some(TimestampType)
106-
case v: java.time.Instant =>
107-
exprBuilder.setLongVal(v.toEpochMilli)
108-
Some(TimestampType)
109-
case _ =>
110-
logWarning(s"Unsupported literal type: ${value.getClass}")
111-
None
112-
}
113-
if (dataType.isDefined) {
114-
val dt = serializeDataType(dataType.get)
115-
exprBuilder.setDatatype(dt.get)
116-
Some(
117-
ExprOuterClass.Expr
118-
.newBuilder()
119-
.setLiteral(exprBuilder)
120-
.build())
121-
} else {
122-
None
113+
case _: DateType =>
114+
value match {
115+
case v: LocalDate => exprBuilder.setIntVal(DateTimeUtils.localDateToDays(v))
116+
case v: Date => exprBuilder.setIntVal(DateTimeUtils.fromJavaDate(v))
117+
case v: Int => exprBuilder.setIntVal(v)
118+
case _ =>
119+
valueIsSet = false
120+
logWarning(s"Unexpected date type '${value.getClass}' for value '$value'")
121+
}
122+
exprBuilder.setIntVal(value.asInstanceOf[Int])
123+
case dt =>
124+
valueIsSet = false
125+
logWarning(s"Unexpected data type '$dt' for literal value '$value'")
123126
}
124127
}
128+
129+
val dt = serializeDataType(dataType)
130+
131+
if (valueIsSet && dt.isDefined) {
132+
exprBuilder.setDatatype(dt.get)
133+
134+
Some(
135+
ExprOuterClass.Expr
136+
.newBuilder()
137+
.setLiteral(exprBuilder)
138+
.build())
139+
} else {
140+
None
141+
}
125142
}
126143

127144
def createUnaryExpr(

0 commit comments

Comments (0)