
Commit d632135

Merge remote-tracking branch 'upstream/master' into cv_min
2 parents: 16e3b2c + 1b6fe9b

121 files changed: +2093 additions, -230 deletions


R/pkg/NAMESPACE

Lines changed: 5 additions & 0 deletions
@@ -10,6 +10,11 @@ export("sparkR.init")
 export("sparkR.stop")
 export("print.jobj")
 
+# Job group lifecycle management methods
+export("setJobGroup",
+       "clearJobGroup",
+       "cancelJobGroup")
+
 exportClasses("DataFrame")
 
 exportMethods("arrange",

R/pkg/R/sparkR.R

Lines changed: 44 additions & 0 deletions
@@ -278,3 +278,47 @@ sparkRHive.init <- function(jsc = NULL) {
   assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
   hiveCtx
 }
+
+#' Assigns a group ID to all the jobs started by this thread until the group ID is set to a
+#' different value or cleared.
+#'
+#' @param sc existing spark context
+#' @param groupId the ID to be assigned to job groups
+#' @param description description for the job group ID
+#' @param interruptOnCancel flag to indicate if the job is interrupted on job cancellation
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' setJobGroup(sc, "myJobGroup", "My job group description", TRUE)
+#'}
+
+setJobGroup <- function(sc, groupId, description, interruptOnCancel) {
+  callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel)
+}
+
+#' Clear current job group ID and its description
+#'
+#' @param sc existing spark context
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' clearJobGroup(sc)
+#'}
+
+clearJobGroup <- function(sc) {
+  callJMethod(sc, "clearJobGroup")
+}
+
+#' Cancel active jobs for the specified group
+#'
+#' @param sc existing spark context
+#' @param groupId the ID of job group to be cancelled
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' cancelJobGroup(sc, "myJobGroup")
+#'}
+
+cancelJobGroup <- function(sc, groupId) {
+  callJMethod(sc, "cancelJobGroup", groupId)
+}
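
These R wrappers simply forward to the JVM-side SparkContext through callJMethod. For comparison, a minimal sketch of the same job-group lifecycle written directly against the Scala API is shown below; the object name, app name, and local master are illustrative assumptions, not part of this commit.

import org.apache.spark.{SparkConf, SparkContext}

object JobGroupExample {
  def main(args: Array[String]): Unit = {
    // Illustrative local context; any existing SparkContext would do.
    val sc = new SparkContext(new SparkConf().setAppName("job-group-demo").setMaster("local[*]"))
    try {
      // Tag every job submitted by this thread; interruptOnCancel = true lets
      // cancelJobGroup interrupt tasks that are already running.
      sc.setJobGroup("myJobGroup", "My job group description", interruptOnCancel = true)
      sc.parallelize(1 to 1000).count()

      // Cancel anything still running under the group, then clear the tag.
      sc.cancelJobGroup("myJobGroup")
      sc.clearJobGroup()
    } finally {
      sc.stop()
    }
  }
}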

R/pkg/inst/tests/test_context.R

Lines changed: 7 additions & 0 deletions
@@ -48,3 +48,10 @@ test_that("rdd GC across sparkR.stop", {
   count(rdd3)
   count(rdd4)
 })
+
+test_that("job group functions can be called", {
+  sc <- sparkR.init()
+  setJobGroup(sc, "groupId", "job description", TRUE)
+  cancelJobGroup(sc, "groupId")
+  clearJobGroup(sc)
+})

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 2 additions & 2 deletions
@@ -94,8 +94,8 @@ class TaskMetrics extends Serializable {
    */
   private var _diskBytesSpilled: Long = _
   def diskBytesSpilled: Long = _diskBytesSpilled
-  def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value
-  def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value
+  private[spark] def incDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled += value
+  private[spark] def decDiskBytesSpilled(value: Long): Unit = _diskBytesSpilled -= value
 
   /**
    * If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read

core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala

Lines changed: 106 additions & 6 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.serializer
 
-import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField}
+import java.io._
 import java.lang.reflect.{Field, Method}
 import java.security.AccessController
 
@@ -62,7 +62,7 @@ private[spark] object SerializationDebugger extends Logging {
    *
    * It does not yet handle writeObject override, but that shouldn't be too hard to do either.
    */
-  def find(obj: Any): List[String] = {
+  private[serializer] def find(obj: Any): List[String] = {
     new SerializationDebugger().visit(obj, List.empty)
   }
 
@@ -125,6 +125,12 @@ private[spark] object SerializationDebugger extends Logging {
       return List.empty
     }
 
+    /**
+     * Visit an externalizable object.
+     * Since writeExternal() can choose to add arbitrary objects at the time of serialization,
+     * the only way to capture all the objects it will serialize is by using a
+     * dummy ObjectOutput that collects all the relevant objects for further testing.
+     */
     private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] =
     {
       val fieldList = new ListObjectOutput
@@ -145,17 +151,50 @@ private[spark] object SerializationDebugger extends Logging {
       // An object contains multiple slots in serialization.
       // Get the slots and visit fields in all of them.
       val (finalObj, desc) = findObjectAndDescriptor(o)
+
+      // If the object has been replaced using writeReplace(),
+      // then call visit() on it again to test its type again.
+      if (!finalObj.eq(o)) {
+        return visit(finalObj, s"writeReplace data (class: ${finalObj.getClass.getName})" :: stack)
+      }
+
+      // Every class is associated with one or more "slots", and each slot refers to one of the
+      // parent classes of this class. These slots are used by the ObjectOutputStream
+      // serialization code to recursively serialize the fields of an object and
+      // its parent classes. For example, given the following classes:
+      //
+      //   class ParentClass(parentField: Int)
+      //   class ChildClass(childField: Int) extends ParentClass(1)
+      //
+      // serializing an object Obj of type ChildClass requires first serializing the fields
+      // of ParentClass (that is, parentField), and then serializing the fields of ChildClass
+      // (that is, childField). Correspondingly, there will be two slots related to this object:
+      //
+      //   1. ParentClass slot, which will be used to serialize parentField of Obj
+      //   2. ChildClass slot, which will be used to serialize childField of Obj
+      //
+      // The following code uses the description of each slot to find the fields in the
+      // corresponding object to visit.
+      //
       val slotDescs = desc.getSlotDescs
       var i = 0
       while (i < slotDescs.length) {
         val slotDesc = slotDescs(i)
         if (slotDesc.hasWriteObjectMethod) {
-          // TODO: Handle classes that specify writeObject method.
+          // If the class type corresponding to the current slot has writeObject() defined,
+          // then it is not obvious which fields of the class will be serialized, as writeObject()
+          // can choose arbitrary fields for serialization. This case is handled separately.
+          val elem = s"writeObject data (class: ${slotDesc.getName})"
+          val childStack = visitSerializableWithWriteObjectMethod(finalObj, elem :: stack)
+          if (childStack.nonEmpty) {
+            return childStack
+          }
         } else {
+          // Visit all the field objects of the class corresponding to the current slot.
           val fields: Array[ObjectStreamField] = slotDesc.getFields
           val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields)
           val numPrims = fields.length - objFieldValues.length
-          desc.getObjFieldValues(finalObj, objFieldValues)
+          slotDesc.getObjFieldValues(finalObj, objFieldValues)
 
           var j = 0
           while (j < objFieldValues.length) {
@@ -169,18 +208,54 @@ private[spark] object SerializationDebugger extends Logging {
             }
             j += 1
           }
-
         }
         i += 1
       }
       return List.empty
     }
+
+    /**
+     * Visit a serializable object which has the writeObject() method defined.
+     * Since writeObject() can choose to add arbitrary objects at the time of serialization,
+     * the only way to capture all the objects it will serialize is by using a
+     * dummy ObjectOutputStream that collects all the relevant fields for further testing.
+     * This is similar to how externalizable objects are visited.
+     */
+    private def visitSerializableWithWriteObjectMethod(
+        o: Object, stack: List[String]): List[String] = {
+      val innerObjectsCatcher = new ListObjectOutputStream
+      var notSerializableFound = false
+      try {
+        innerObjectsCatcher.writeObject(o)
+      } catch {
+        case io: IOException =>
+          notSerializableFound = true
+      }
+
+      // If something was not serializable, then visit the captured objects.
+      // Otherwise, all the captured objects are safely serializable, so there is no need to
+      // visit them. As an optimization, just add them to the visited list.
+      if (notSerializableFound) {
+        val innerObjects = innerObjectsCatcher.outputArray
+        var k = 0
+        while (k < innerObjects.length) {
+          val childStack = visit(innerObjects(k), stack)
+          if (childStack.nonEmpty) {
+            return childStack
+          }
+          k += 1
+        }
+      } else {
+        visited ++= innerObjectsCatcher.outputArray
+      }
+      return List.empty
+    }
   }
 
   /**
   * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles
   * writeReplace in Serializable. It starts with the object itself, and keeps calling the
-  * writeReplace method until there is no more
+  * writeReplace method until there is no more.
   */
   @tailrec
   private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = {
@@ -220,6 +295,31 @@ private[spark] object SerializationDebugger extends Logging {
     override def writeByte(i: Int): Unit = {}
   }
 
+  /** An output stream that emulates /dev/null */
+  private class NullOutputStream extends OutputStream {
+    override def write(b: Int) { }
+  }
+
+  /**
+   * A dummy [[ObjectOutputStream]] that saves the list of objects written to it and returns
+   * them through `outputArray`. This works by using the [[ObjectOutputStream]]'s `replaceObject()`
+   * method, which gets called on every object only if replacing is enabled. So this subclass
+   * of [[ObjectOutputStream]] enables replacing, and uses replaceObject to get the objects that
+   * are being serialized. The serialized bytes are ignored by sending them to a
+   * [[NullOutputStream]], which acts like a /dev/null.
+   */
+  private class ListObjectOutputStream extends ObjectOutputStream(new NullOutputStream) {
+    private val output = new mutable.ArrayBuffer[Any]
+    this.enableReplaceObject(true)
+
+    def outputArray: Array[Any] = output.toArray
+
+    override def replaceObject(obj: Object): Object = {
+      output += obj
+      obj
+    }
+  }
+
   /** An implicit class that allows us to call private methods of ObjectStreamClass. */
   implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal {
     def getSlotDescs: Array[ObjectStreamClass] = {
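
The ListObjectOutputStream trick above is the core of the new writeObject() handling: serialize into a sink that throws the bytes away, and use replaceObject() to record every object the stream touches. Here is a minimal standalone sketch of the same idea outside Spark; the class names and demo objects are illustrative, not part of this commit.

import java.io.{ObjectOutputStream, OutputStream}
import scala.collection.mutable

// Discards every byte, like /dev/null.
class DevNull extends OutputStream {
  override def write(b: Int): Unit = {}
}

// Serializes into the void, but records each object the stream visits:
// once enableReplaceObject(true) is set, replaceObject() is invoked per object.
class CapturingObjectOutputStream extends ObjectOutputStream(new DevNull) {
  private val seen = mutable.ArrayBuffer[Any]()
  this.enableReplaceObject(true)

  override def replaceObject(obj: Object): Object = {
    seen += obj
    obj // keep the object unchanged; we only want to observe it
  }

  def captured: Seq[Any] = seen.toSeq
}

case class Inner(x: Int)
case class Outer(inner: Inner, label: String)

object CaptureDemo extends App {
  val out = new CapturingObjectOutputStream
  out.writeObject(Outer(Inner(1), "demo"))
  // Prints the classes of the outer object and everything serialized along the way.
  out.captured.foreach(o => println(o.getClass.getName))
}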

core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala

Lines changed: 4 additions & 6 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.DeveloperApi
  * size, which is guaranteed to explore all spaces for each key (see
  * http://en.wikipedia.org/wiki/Quadratic_probing).
  *
- * The map can support up to `536870912 (2 ^ 29)` elements.
+ * The map can support up to `375809638 (0.7 * 2 ^ 29)` elements.
  *
  * TODO: Cache the hash values of each key? java.util.HashMap does that.
  */
@@ -199,11 +199,8 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64)
 
   /** Increase table size by 1, rehashing if necessary */
   private def incrementSize() {
-    if (curSize == MAXIMUM_CAPACITY) {
-      throw new IllegalStateException(s"Can't put more that ${MAXIMUM_CAPACITY} elements")
-    }
     curSize += 1
-    if (curSize > growThreshold && capacity < MAXIMUM_CAPACITY) {
+    if (curSize > growThreshold) {
       growTable()
     }
   }
@@ -216,7 +213,8 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64)
   /** Double the table's size and re-hash everything */
   protected def growTable() {
     // capacity < MAXIMUM_CAPACITY (2 ^ 29) so capacity * 2 won't overflow
-    val newCapacity = (capacity * 2).min(MAXIMUM_CAPACITY)
+    val newCapacity = capacity * 2
+    require(newCapacity <= MAXIMUM_CAPACITY, s"Can't contain more than ${growThreshold} elements")
     val newData = new Array[AnyRef](2 * newCapacity)
     val newMask = newCapacity - 1
     // Insert all our old values into the new array. Note that because our old keys are
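
The new documented bound follows from the map's growth threshold: insertion stops at the load factor times the maximum capacity rather than at the raw capacity. A quick sanity check of the arithmetic, assuming a 0.7 load factor as the updated scaladoc implies:

object CapacityBound extends App {
  val maximumCapacity = 1 << 29                  // 536870912, the map's MAXIMUM_CAPACITY
  val loadFactor = 0.7                           // assumed from the `0.7 * 2 ^ 29` comment
  println((loadFactor * maximumCapacity).toLong) // prints 375809638, matching the new scaladoc
}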

core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala

Lines changed: 12 additions & 5 deletions
@@ -325,7 +325,7 @@ class SparkSubmitSuite
     runSparkSubmit(args)
   }
 
-  ignore("includes jars passed in through --jars") {
+  test("includes jars passed in through --jars") {
     val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
     val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
     val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB"))
@@ -340,7 +340,7 @@ class SparkSubmitSuite
   }
 
   // SPARK-7287
-  ignore("includes jars passed in through --packages") {
+  test("includes jars passed in through --packages") {
     val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
     val main = MavenCoordinate("my.great.lib", "mylib", "0.1")
     val dep = MavenCoordinate("my.great.dep", "mylib", "0.1")
@@ -499,9 +499,16 @@ class SparkSubmitSuite
       Seq("./bin/spark-submit") ++ args,
       new File(sparkHome),
       Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome))
-    failAfter(60 seconds) { process.waitFor() }
-    // Ensure we still kill the process in case it timed out
-    process.destroy()
+
+    try {
+      val exitCode = failAfter(60 seconds) { process.waitFor() }
+      if (exitCode != 0) {
+        fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.")
+      }
+    } finally {
+      // Ensure we still kill the process in case it timed out
+      process.destroy()
+    }
   }
 
   private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = {
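
The last hunk makes a hung or failing spark-submit actually fail the test instead of being silently destroyed. The same "wait with a timeout, check the exit code, always clean up" pattern can be sketched with plain java.lang.Process and no ScalaTest dependency; the command and timeout below are illustrative.

import java.util.concurrent.TimeUnit

object WaitAndCleanUp extends App {
  val process = new ProcessBuilder("sleep", "5").start()
  try {
    // Fail fast on timeout instead of hanging forever, and surface the exit code.
    val finished = process.waitFor(60, TimeUnit.SECONDS)
    if (!finished) {
      sys.error("Process timed out after 60 seconds")
    } else if (process.exitValue() != 0) {
      sys.error(s"Process returned with exit code ${process.exitValue()}")
    }
  } finally {
    // Ensure the child is killed even if we timed out or threw above.
    process.destroy()
  }
}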
