1919
2020package org .apache .comet .parquet
2121
22+ import java .math .{BigDecimal => JavaBigDecimal }
23+ import java .sql .{Date , Timestamp }
24+ import java .time .{Instant , LocalDate , LocalDateTime }
25+
2226import org .apache .spark .internal .Logging
27+ import org .apache .spark .sql .catalyst .util .DateTimeUtils
2328import org .apache .spark .sql .types ._
24- import org .apache .spark .unsafe .types .UTF8String
2529
2630import org .apache .comet .serde .ExprOuterClass
2731import org .apache .comet .serde .ExprOuterClass .Expr
2832import org .apache .comet .serde .QueryPlanSerde .serializeDataType
2933
3034object SourceFilterSerde extends Logging {
3135
32- def createNameExpr (name : String , schema : StructType ): Option [ExprOuterClass .Expr ] = {
36+ def createNameExpr (
37+ name : String ,
38+ schema : StructType ): Option [(DataType , ExprOuterClass .Expr )] = {
3339 val filedWithIndex = schema.fields.zipWithIndex.find { case (field, _) =>
3440 field.name == name
3541 }
@@ -43,6 +49,7 @@ object SourceFilterSerde extends Logging {
4349 .setDatatype(dataType.get)
4450 .build()
4551 Some (
52+ field.dataType,
4653 ExprOuterClass .Expr
4754 .newBuilder()
4855 .setBound(boundExpr)
@@ -56,72 +63,82 @@ object SourceFilterSerde extends Logging {
5663
5764 }
5865
/**
 * Creates a native literal expression for a data source filter value.
 *
 * `value` is an external Scala/Java value rather than a Catalyst internal value (refer to
 * `org.apache.spark.sql.catalyst.CatalystTypeConverters.CatalystTypeConverter#toScala`), so
 * date and timestamp values may arrive as `java.sql` or `java.time` objects. Those are
 * converted with `DateTimeUtils` before being stored in the literal.
 *
 * @param value
 *   the filter value; `null` produces a null literal of `dataType`
 * @param dataType
 *   the Spark data type the literal is built for
 * @return
 *   the literal expression, or `None` when `dataType` is not handled here, when `dataType`
 *   cannot be serialized, or when the runtime class of `value` does not fit `dataType`
 */
def createValueExpr(value: Any, dataType: DataType): Option[ExprOuterClass.Expr] = {
  val exprBuilder = ExprOuterClass.Literal.newBuilder()

  // Logs a date/timestamp value of an unexpected runtime class and reports "value not set".
  def unexpectedValue(kind: String): Boolean = {
    logWarning(s"Unexpected $kind type '${value.getClass}' for value '$value'")
    false
  }

  // true when the builder holds a usable value (or an explicit null), false otherwise.
  val valueIsSet: Boolean =
    if (value == null) {
      exprBuilder.setIsNull(true)
      true
    } else {
      exprBuilder.setIsNull(false)
      // Match on the declared type and the runtime class together, so a mismatched value is
      // reported and skipped instead of failing with a ClassCastException.
      (dataType, value) match {
        case (_: BooleanType, v: Boolean) =>
          exprBuilder.setBoolVal(v)
          true
        case (_: ByteType, v: Byte) =>
          exprBuilder.setByteVal(v)
          true
        case (_: ShortType, v: Short) =>
          exprBuilder.setShortVal(v)
          true
        case (_: IntegerType, v: Int) =>
          exprBuilder.setIntVal(v)
          true
        case (_: LongType, v: Long) =>
          exprBuilder.setLongVal(v)
          true
        case (_: FloatType, v: Float) =>
          exprBuilder.setFloatVal(v)
          true
        case (_: DoubleType, v: Double) =>
          exprBuilder.setDoubleVal(v)
          true
        case (_: StringType, v: String) =>
          exprBuilder.setStringVal(v)
          true
        case (_: TimestampType, v: Timestamp) =>
          exprBuilder.setLongVal(DateTimeUtils.fromJavaTimestamp(v))
          true
        case (_: TimestampType, v: Instant) =>
          exprBuilder.setLongVal(DateTimeUtils.instantToMicros(v))
          true
        case (_: TimestampType, v: Long) =>
          exprBuilder.setLongVal(v)
          true
        case (_: TimestampNTZType, v: LocalDateTime) =>
          exprBuilder.setLongVal(DateTimeUtils.localDateTimeToMicros(v))
          true
        case (_: TimestampNTZType, v: Long) =>
          exprBuilder.setLongVal(v)
          true
        case (_: DecimalType, v: JavaBigDecimal) =>
          // Pass decimal literal as bytes.
          // NOTE(review): the unscaled value depends on the scale of `v`; this assumes it
          // already matches the scale of `dataType` - confirm against callers.
          exprBuilder.setDecimalVal(
            com.google.protobuf.ByteString.copyFrom(v.unscaledValue.toByteArray))
          true
        case (_: BinaryType, v: Array[Byte]) =>
          exprBuilder.setBytesVal(com.google.protobuf.ByteString.copyFrom(v))
          true
        case (_: DateType, v: LocalDate) =>
          exprBuilder.setIntVal(DateTimeUtils.localDateToDays(v))
          true
        case (_: DateType, v: Date) =>
          exprBuilder.setIntVal(DateTimeUtils.fromJavaDate(v))
          true
        case (_: DateType, v: Int) =>
          exprBuilder.setIntVal(v)
          true
        // Supported data type, but the value has an unexpected runtime class.
        case (_: TimestampType | _: TimestampNTZType, _) =>
          unexpectedValue("timestamp")
        case (_: DateType, _) =>
          unexpectedValue("date")
        case (
              _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
              _: FloatType | _: DoubleType | _: StringType | _: DecimalType | _: BinaryType,
              _) =>
          logWarning(
            s"Unexpected value type '${value.getClass}' for literal '$value' " +
              s"of data type '$dataType'")
          false
        case (dt, _) =>
          logWarning(s"Unexpected data type '$dt' for literal value '$value'")
          false
      }
    }

  val dt = serializeDataType(dataType)

  if (valueIsSet && dt.isDefined) {
    exprBuilder.setDatatype(dt.get)

    Some(
      ExprOuterClass.Expr
        .newBuilder()
        .setLiteral(exprBuilder)
        .build())
  } else {
    None
  }
}
126143
127144 def createUnaryExpr (
0 commit comments