@@ -235,26 +235,39 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
235235 * [[InSet (value, HashSet[Literal])]] which is much faster.
236236 */
237237object OptimizeIn extends Rule [LogicalPlan ] {
238+ def optimizeIn (expr : In , v : Expression , list : Seq [Expression ]): Expression = {
239+ val newList = ExpressionSet (list).toSeq
240+ if (newList.length == 1
241+ // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
242+ // TODO: we exclude them in this rule.
243+ && ! v.isInstanceOf [CreateNamedStruct ]
244+ && ! newList.head.isInstanceOf [CreateNamedStruct ]) {
245+ EqualTo (v, newList.head)
246+ } else if (newList.length > SQLConf .get.optimizerInSetConversionThreshold) {
247+ val hSet = newList.map(e => e.eval(EmptyRow ))
248+ InSet (v, HashSet () ++ hSet)
249+ } else if (newList.length < list.length) {
250+ expr.copy(list = newList)
251+ } else { // newList.length == list.length && newList.length > 1
252+ expr
253+ }
254+ }
255+
238256 def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
239257 case q : LogicalPlan => q transformExpressionsDown {
240258 case In (v, list) if list.isEmpty =>
241259 // When v is not nullable, the following expression will be optimized
242260 // to FalseLiteral which is tested in OptimizeInSuite.scala
243261 If (IsNotNull (v), FalseLiteral , Literal (null , BooleanType ))
244- case expr @ In (v, list) if expr.inSetConvertible =>
245- val newList = ExpressionSet (list).toSeq
246- if (newList.length == 1
247- // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
248- // TODO: we exclude them in this rule.
249- && ! v.isInstanceOf [CreateNamedStruct ]
250- && ! newList.head.isInstanceOf [CreateNamedStruct ]) {
251- EqualTo (v, newList.head)
252- } else if (newList.length > SQLConf .get.optimizerInSetConversionThreshold) {
253- val hSet = newList.map(e => e.eval(EmptyRow ))
254- InSet (v, HashSet () ++ hSet)
255- } else if (newList.length < list.length) {
256- expr.copy(list = newList)
257- } else { // newList.length == list.length && newList.length > 1
262+ case expr @ In (v, list) =>
263+ // split list to 2 parts so that we can push down convertible part
264+ val (convertible, nonConvertible) = list.partition(_.isInstanceOf [Literal ])
265+ if (convertible.nonEmpty && nonConvertible.isEmpty) {
266+ optimizeIn(expr, v, list)
267+ } else if (convertible.nonEmpty && nonConvertible.nonEmpty) {
268+ val optimizedIn = optimizeIn(In (v, convertible), v, convertible)
269+ And (optimizedIn, In (v, nonConvertible))
270+ } else {
258271 expr
259272 }
260273 }
0 commit comments