Skip to content

Commit

Permalink
[SPARK-5550] [SQL] Support the case insensitive for UDF
Browse files Browse the repository at this point in the history
SQL in HiveContext should be case insensitive; however, the following query will fail:

```scala
udf.register("random0", ()  => { Math.random()})
assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
```

Author: Cheng Hao <hao.cheng@intel.com>

Closes apache#4326 from chenghao-intel/udf_case_sensitive and squashes the following commits:

485cf66 [Cheng Hao] Support the case insensitive for UDF
  • Loading branch information
chenghao-intel authored and marmbrus committed Feb 3, 2015
1 parent 0c20ce6 commit ca7a6cd
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,25 @@ trait FunctionRegistry {
def registerFunction(name: String, builder: FunctionBuilder): Unit

def lookupFunction(name: String, children: Seq[Expression]): Expression

def caseSensitive: Boolean
}

/**
 * A registry mix-in that consults its own `functionBuilders` map first and
 * falls back to the underlying registry (via `super`) when the name is not
 * registered here. Lookup case sensitivity is controlled by the inherited
 * `caseSensitive` flag from [[FunctionRegistry]].
 */
trait OverrideFunctionRegistry extends FunctionRegistry {

  // Backing map; key normalization honors the mix-in's caseSensitive setting.
  // NOTE(review): this val is initialized using the abstract `caseSensitive`
  // member — implementors should provide it as a `def` (not a `val`) to avoid
  // trait-initialization-order surprises.
  val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)

  def registerFunction(name: String, builder: FunctionBuilder) = {
    functionBuilders.put(name, builder)
  }

  abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
    // Prefer a locally registered builder; otherwise delegate to the base registry.
    functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children))
  }
}

class SimpleFunctionRegistry extends FunctionRegistry {
val functionBuilders = new mutable.HashMap[String, FunctionBuilder]()
class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistry {
val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive)

def registerFunction(name: String, builder: FunctionBuilder) = {
functionBuilders.put(name, builder)
Expand All @@ -64,4 +66,30 @@ object EmptyFunctionRegistry extends FunctionRegistry {
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
throw new UnsupportedOperationException
}

def caseSensitive: Boolean = ???
}

/**
 * A mutable hash map keyed by `String` that supports either case-sensitive or
 * case-insensitive keys. Every key is passed through `normalizer` before any
 * read or write, so all operations agree on a single canonical key form.
 * TODO move this into util folder?
 */
object StringKeyHashMap {
  /**
   * Build a [[StringKeyHashMap]]. When `caseSensitive` is false, keys are
   * normalized to lower case, so e.g. "RANDOM0" and "random0" address the
   * same entry.
   *
   * NOTE(review): `toLowerCase` is locale-sensitive (e.g. Turkish dotless i);
   * consider `toLowerCase(Locale.ROOT)` if identifiers may contain such
   * characters — TODO confirm desired behavior.
   */
  def apply[T](caseSensitive: Boolean): StringKeyHashMap[T] =
    if (caseSensitive) new StringKeyHashMap[T](identity)
    else new StringKeyHashMap[T](_.toLowerCase)
}

class StringKeyHashMap[T](normalizer: String => String) {
  private val base = new collection.mutable.HashMap[String, T]()

  /** Look up `key`; throws NoSuchElementException if it is absent. */
  def apply(key: String): T = base(normalizer(key))

  def get(key: String): Option[T] = base.get(normalizer(key))
  def put(key: String, value: T): Option[T] = base.put(normalizer(key), value)
  def remove(key: String): Option[T] = base.remove(normalizer(key))

  /** Iterate over (normalized key, value) pairs. */
  def iterator: Iterator[(String, T)] = base.iterator
}

Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true)

@transient
// Plain SQLContext uses case-sensitive function names (Hive overrides this with false).
protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(caseSensitive = true)

@transient
protected[sql] lazy val analyzer: Analyzer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
// Note that HiveUDFs will be overridden by functions registered in this context.
@transient
override protected[sql] lazy val functionRegistry =
  new HiveFunctionRegistry with OverrideFunctionRegistry {
    // Hive resolves function names case-insensitively, so the override map must too.
    def caseSensitive = false
  }

/* An analyzer that uses the Hive metastore. */
@transient
Expand Down
36 changes: 36 additions & 0 deletions sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive

/* Implicits */

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.hive.test.TestHive._

case class FunctionResult(f1: String, f2: String)

/** Verifies that UDF names registered via `udf.register` resolve case-insensitively. */
class UDFSuite extends QueryTest {
  test("UDF case insensitive") {
    // Register under mixed casings, then query with different casings.
    udf.register("random0", () => Math.random())
    udf.register("RANDOM1", () => Math.random())
    udf.register("strlenScala", (s: String, i: Int) => s.length + i)
    assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
    assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
    assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
  }
}

0 comments on commit ca7a6cd

Please sign in to comment.