Commit 396c0e1

Avoid serializing the Simple UDF instance

1 parent e9c3212

File tree: 1 file changed, +32 additions, -21 deletions

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

```diff
@@ -86,39 +86,50 @@ class HiveFunctionCache(var functionClassName: String) extends java.io.Externalizable {
   private var instance: Any = null
 
   def writeExternal(out: java.io.ObjectOutput) {
-    // Some of the UDFs are serializable, but some others are not;
-    // Hive Utilities can handle both cases
-    val baos = new java.io.ByteArrayOutputStream()
-    HiveShim.serializePlan(instance, baos)
-    val functionInBytes = baos.toByteArray
-
     // output the function name
     out.writeUTF(functionClassName)
 
-    // output the function bytes
-    out.writeInt(functionInBytes.length)
-    out.write(functionInBytes, 0, functionInBytes.length)
+    // write a flag indicating whether instance is null
+    out.writeBoolean(instance != null)
+    if (instance != null) {
+      // Some of the UDFs are serializable, but some others are not;
+      // Hive Utilities can handle both cases
+      val baos = new java.io.ByteArrayOutputStream()
+      HiveShim.serializePlan(instance, baos)
+      val functionInBytes = baos.toByteArray
+
+      // output the function bytes
+      out.writeInt(functionInBytes.length)
+      out.write(functionInBytes, 0, functionInBytes.length)
+    }
   }
 
   def readExternal(in: java.io.ObjectInput) {
     // read the function name
     functionClassName = in.readUTF()
 
-    // read the function in bytes
-    val functionInBytesLength = in.readInt()
-    val functionInBytes = new Array[Byte](functionInBytesLength)
-    in.read(functionInBytes, 0, functionInBytesLength)
+    if (in.readBoolean()) {
+      // the instance was not null, so read the function bytes
+      val functionInBytesLength = in.readInt()
+      val functionInBytes = new Array[Byte](functionInBytesLength)
+      in.read(functionInBytes, 0, functionInBytesLength)
 
-    // deserialize the function object via Hive Utilities
-    instance = HiveShim.deserializePlan(new java.io.ByteArrayInputStream(functionInBytes),
-      getContextOrSparkClassLoader.loadClass(functionClassName))
+      // deserialize the function object via Hive Utilities
+      instance = HiveShim.deserializePlan(new java.io.ByteArrayInputStream(functionInBytes),
+        getContextOrSparkClassLoader.loadClass(functionClassName))
+    }
   }
 
-  def createFunction[UDFType]() = {
-    if (instance == null) {
-      instance = getContextOrSparkClassLoader.loadClass(functionClassName).newInstance
+  def createFunction[UDFType](alwaysCreateNewInstance: Boolean = false) = {
+    if (alwaysCreateNewInstance) {
+      getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+    } else {
+      if (instance == null) {
+        instance = getContextOrSparkClassLoader.loadClass(functionClassName).newInstance
+      }
+      instance.asInstanceOf[UDFType]
     }
-    instance.asInstanceOf[UDFType]
   }
 }
 
```
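The new writeExternal/readExternal pair follows a common Externalizable pattern: write a presence flag before an optional payload, so a null field round-trips as a single boolean instead of being forced through the serializer. Below is a minimal, self-contained sketch of the same pattern; `FlaggedHolder` and `payload` are illustrative names, not from the Spark source.

```scala
import java.io._

// Illustrative class (not from the Spark source): serialize an optional
// payload behind a presence flag, mirroring the writeBoolean/readBoolean
// handshake in HiveFunctionCache above.
class FlaggedHolder(var payload: String) extends Externalizable {
  def this() = this(null) // Externalizable requires a public no-arg constructor

  def writeExternal(out: ObjectOutput): Unit = {
    out.writeBoolean(payload != null)          // flag first ...
    if (payload != null) out.writeUTF(payload) // ... payload only if present
  }

  def readExternal(in: ObjectInput): Unit = {
    if (in.readBoolean()) payload = in.readUTF() // read only when flagged
  }
}

object FlaggedHolderDemo extends App {
  val buf = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(buf)
  oos.writeObject(new FlaggedHolder(null)) // a null payload serializes cleanly
  oos.close()

  val ois = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
  println(ois.readObject().asInstanceOf[FlaggedHolder].payload) // prints: null
}
```

The reader must mirror the writer exactly: one readBoolean, then the payload only when the flag was set, which is why the commit guards both methods with the same condition.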

```diff
@@ -130,7 +141,7 @@ private[hive] case class HiveSimpleUdf(cache: HiveFunctionCache, children: Seq[Expression])
   def nullable = true
 
   @transient
-  lazy val function = cache.createFunction[UDFType]()
+  lazy val function = cache.createFunction[UDFType](true) // Simple UDF should not be serialized.
 
   @transient
   protected lazy val method =
```
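Passing `true` here works together with the `@transient` annotation: the `function` field is never serialized with the expression, the cache's `instance` field stays unset on the driver, so `writeExternal` takes the flag-only path, and each executor rebuilds a fresh UDF by reflection on first use. A minimal sketch of that `@transient lazy val` pattern, with illustrative names (`UdfHolder` is not from the Spark source):

```scala
// Illustrative holder: only the class name (a String) crosses the wire;
// the function instance itself is rebuilt lazily in each deserializing JVM.
class UdfHolder(val functionClassName: String) extends Serializable {
  @transient lazy val function: AnyRef =
    // After deserialization the lazy val is uninitialized, so the first
    // access reflectively instantiates the function from its class name.
    Class.forName(functionClassName).newInstance().asInstanceOf[AnyRef]
}

object UdfHolderDemo extends App {
  // Any class with a public no-arg constructor stands in for a Hive UDF here.
  val holder = new UdfHolder("java.lang.StringBuilder")
  println(holder.function.getClass.getName) // java.lang.StringBuilder
}
```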
