
Commit 9da347f

Merge pull request #2 from apache/master

merge upstream changes

2 parents: aa5b4b5 + 8e7d5ba

53 files changed: +1976 / -264 lines changed


core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 3 additions & 2 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import java.io.File
+import java.net.Socket
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
@@ -102,10 +103,10 @@ class SparkEnv (
   }
 
   private[spark]
-  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
+  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
     synchronized {
       val key = (pythonExec, envVars)
-      pythonWorkers(key).stop()
+      pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
 }
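
The point of this change: destroyPythonWorker now takes the specific worker Socket, so SparkEnv can ask the matching PythonWorkerFactory to stop just that worker instead of shutting down everything behind the (pythonExec, envVars) key, and the Option-returning get avoids a NoSuchElementException when no factory is registered. A minimal sketch of this lookup-and-delegate pattern follows; the registry object and the WorkerFactory trait are illustrative stand-ins, not Spark's actual classes.

import java.net.Socket
import scala.collection.mutable

// Illustrative stand-in for the real PythonWorkerFactory; only the shape matters here.
trait WorkerFactory { def stopWorker(worker: Socket): Unit }

object WorkerRegistrySketch {
  // Factories are keyed by (python executable, environment), as in SparkEnv.
  private val pythonWorkers =
    mutable.HashMap[(String, Map[String, String]), WorkerFactory]()

  // New shape of destroyPythonWorker: the extra Socket pinpoints one worker.
  def destroyPythonWorker(pythonExec: String,
                          envVars: Map[String, String],
                          worker: Socket): Unit = synchronized {
    val key = (pythonExec, envVars)
    pythonWorkers.get(key).foreach(_.stopWorker(worker))
  }
}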

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 5 deletions
@@ -62,8 +62,8 @@ private[spark] class PythonRDD(
     val env = SparkEnv.get
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
-    val worker: Socket = env.createPythonWorker(pythonExec,
-      envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
+    envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread
+    val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
       if (!context.completed) {
         try {
           logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-          env.destroyPythonWorker(pythonExec, envVars.toMap)
+          env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
         } catch {
           case e: Exception =>
             logError("Exception when trying to kill worker", e)
@@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging {
 
   /**
   * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
-   * This function is outdated, PySpark does not use it anymore
   */
-  @deprecated
+  @deprecated("PySpark does not use it anymore", "1.1")
  def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
    pyRDD.rdd.mapPartitions { iter =>
      val unpickle = new Unpickler
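
The last hunk swaps a bare @deprecated (plus a free-form "outdated" comment) for Scala's two-argument form @deprecated(message, since), so the compiler's warning carries the reason and the version it was deprecated in. A tiny, generic illustration; the stub method below is not Spark code.

object DeprecationSketch {
  @deprecated("PySpark does not use it anymore", "1.1")
  def pythonToJavaMapLike(): Unit = ()

  // Compiling a caller with -deprecation prints the message and version above.
  def caller(): Unit = pythonToJavaMapLike()
}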

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 49 additions & 15 deletions
@@ -17,9 +17,11 @@
 
 package org.apache.spark.api.python
 
-import java.io.{DataInputStream, InputStream, OutputStreamWriter}
+import java.lang.Runtime
+import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter}
 import java.net.{InetAddress, ServerSocket, Socket, SocketException}
 
+import scala.collection.mutable
 import scala.collection.JavaConversions._
 
 import org.apache.spark._
@@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
+  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+
+  var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
   val pythonPath = PythonUtils.mergePythonPaths(
     PythonUtils.sparkPythonPath,
@@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
   */
  private def createThroughDaemon(): Socket = {
+
+    def createSocket(): Socket = {
+      val socket = new Socket(daemonHost, daemonPort)
+      val pid = new DataInputStream(socket.getInputStream).readInt()
+      if (pid < 0) {
+        throw new IllegalStateException("Python daemon failed to launch worker")
+      }
+      daemonWorkers.put(socket, pid)
+      socket
+    }
+
    synchronized {
      // Start the daemon if it hasn't been started
      startDaemon()
 
      // Attempt to connect, restart and retry once if it fails
      try {
-        val socket = new Socket(daemonHost, daemonPort)
-        val launchStatus = new DataInputStream(socket.getInputStream).readInt()
-        if (launchStatus != 0) {
-          throw new IllegalStateException("Python daemon failed to launch worker")
-        }
-        socket
+        createSocket()
      } catch {
        case exc: SocketException =>
          logWarning("Failed to open socket to Python daemon:", exc)
          logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
          stopDaemon()
          startDaemon()
-          new Socket(daemonHost, daemonPort)
+          createSocket()
      }
    }
  }
@@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
      // Wait for it to connect to our socket
      serverSocket.setSoTimeout(10000)
      try {
-        return serverSocket.accept()
+        val socket = serverSocket.accept()
+        simpleWorkers.put(socket, worker)
+        return socket
      } catch {
        case e: Exception =>
          throw new SparkException("Python worker did not connect back in time", e)
@@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
  private def stopDaemon() {
    synchronized {
-      // Request shutdown of existing daemon by sending SIGTERM
-      if (daemon != null) {
-        daemon.destroy()
-      }
+      if (useDaemon) {
+        // Request shutdown of existing daemon by sending SIGTERM
+        if (daemon != null) {
+          daemon.destroy()
+        }
 
-      daemon = null
-      daemonPort = 0
+        daemon = null
+        daemonPort = 0
+      } else {
+        simpleWorkers.mapValues(_.destroy())
+      }
    }
  }
 
  def stop() {
    stopDaemon()
  }
+
+  def stopWorker(worker: Socket) {
+    if (useDaemon) {
+      if (daemon != null) {
+        daemonWorkers.get(worker).foreach { pid =>
+          // tell daemon to kill worker by pid
+          val output = new DataOutputStream(daemon.getOutputStream)
+          output.writeInt(pid)
+          output.flush()
+          daemon.getOutputStream.flush()
+        }
+      }
+    } else {
+      simpleWorkers.get(worker).foreach(_.destroy())
+    }
+    worker.close()
+  }
 }
 
 private object PythonWorkerFactory {
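
The factory now remembers every worker it hands out: daemon-launched sockets map to the pid the daemon reported at connect time (daemonWorkers), and directly-launched sockets map to their Process (simpleWorkers). stopWorker can therefore kill exactly one worker, either by writing its pid to the daemon's stdin or by destroying its Process, and the weak keys let entries disappear once nobody references the socket anymore. Below is a standalone sketch of the daemon-side half of this bookkeeping, assuming (as the diff suggests) a control protocol of "daemon sends the new worker's pid; writing a pid back asks the daemon to kill that worker". The class name and constructor parameters are placeholders, not Spark's API.

import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket
import scala.collection.mutable

class DaemonWorkerTrackerSketch(daemon: Process, daemonHost: String, daemonPort: Int) {
  // Weak keys: once the worker socket is no longer referenced elsewhere,
  // its entry can be collected instead of pinning the map forever.
  private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()

  // Connect to the daemon; it answers with the forked worker's pid (negative on failure).
  def createWorker(): Socket = {
    val socket = new Socket(daemonHost, daemonPort)
    val pid = new DataInputStream(socket.getInputStream).readInt()
    if (pid < 0) {
      throw new IllegalStateException("daemon failed to launch worker")
    }
    daemonWorkers.put(socket, pid)
    socket
  }

  // Ask the daemon to kill one specific worker by pid, then close our end of the socket.
  def stopWorker(worker: Socket): Unit = {
    daemonWorkers.get(worker).foreach { pid =>
      val out = new DataOutputStream(daemon.getOutputStream)
      out.writeInt(pid)
      out.flush()
    }
    worker.close()
  }
}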

core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala

Lines changed: 2 additions & 3 deletions
@@ -35,16 +35,15 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
  /**
   * Calling reset to avoid memory leak:
   * http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
-   * But only call it every 10,000th time to avoid bloated serialization streams (when
+   * But only call it every 100th time to avoid bloated serialization streams (when
   * the stream 'resets' object class descriptions have to be re-written)
   */
  def writeObject[T: ClassTag](t: T): SerializationStream = {
    objOut.writeObject(t)
+    counter += 1
    if (counterReset > 0 && counter >= counterReset) {
      objOut.reset()
      counter = 0
-    } else {
-      counter += 1
    }
    this
  }
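
JavaSerializationStream now increments counter on every write and resets the underlying ObjectOutputStream once counter reaches counterReset; the old else branch skipped the increment on the write that triggered a reset, making the effective interval one write longer than counterReset. The doc comment's example interval also drops from every 10,000th to every 100th write. reset() matters because ObjectOutputStream keeps a back-reference to every object it has written, which leaks memory on a long-lived stream, while resetting too often bloats the output because class descriptors must be re-written. A self-contained sketch of the same count-and-reset pattern; the class and parameter names below are illustrative, not Spark's.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

class CountingObjectStream(out: ObjectOutputStream, resetEvery: Int) {
  private var counter = 0

  def write(obj: AnyRef): Unit = {
    out.writeObject(obj)
    counter += 1                        // count the write that just happened
    if (resetEvery > 0 && counter >= resetEvery) {
      out.reset()                       // drop the back-reference table
      counter = 0
    }
  }
}

object CountingObjectStreamExample extends App {
  val bytes  = new ByteArrayOutputStream()
  val stream = new CountingObjectStream(new ObjectOutputStream(bytes), resetEvery = 100)
  (1 to 1000).foreach(i => stream.write(s"record-$i"))
}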

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 65 additions & 21 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.util.collection
 
-import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException}
+import java.io._
 import java.util.Comparator
 
 import scala.collection.BufferedIterator
@@ -28,7 +28,7 @@ import com.google.common.io.ByteStreams
 
 import org.apache.spark.{Logging, SparkEnv}
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
 
@@ -199,13 +199,16 @@ class ExternalAppendOnlyMap[K, V, C](
 
    // Flush the disk writer's contents to disk, and update relevant variables
    def flush() = {
-      writer.commitAndClose()
-      val bytesWritten = writer.bytesWritten
+      val w = writer
+      writer = null
+      w.commitAndClose()
+      val bytesWritten = w.bytesWritten
      batchSizes.append(bytesWritten)
      _diskBytesSpilled += bytesWritten
      objectsWritten = 0
    }
 
+    var success = false
    try {
      val it = currentMap.destructiveSortedIterator(keyComparator)
      while (it.hasNext) {
@@ -215,16 +218,28 @@ class ExternalAppendOnlyMap[K, V, C](
 
        if (objectsWritten == serializerBatchSize) {
          flush()
-          writer.close()
          writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
        }
      }
      if (objectsWritten > 0) {
        flush()
+      } else if (writer != null) {
+        val w = writer
+        writer = null
+        w.revertPartialWritesAndClose()
      }
+      success = true
    } finally {
-      // Partial failures cannot be tolerated; do not revert partial writes
-      writer.close()
+      if (!success) {
+        // This code path only happens if an exception was thrown above before we set success;
+        // close our stuff and let the exception be thrown further
+        if (writer != null) {
+          writer.revertPartialWritesAndClose()
+        }
+        if (file.exists()) {
+          file.delete()
+        }
+      }
    }
 
    currentMap = new SizeTrackingAppendOnlyMap[K, C]
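
The spill loop above now tracks completion with a success flag: on the normal path every full batch is flushed (with writer nulled out before commit so the finally block cannot touch an already-committed writer), a leftover empty writer is reverted, and only if an exception escapes does the finally block revert partial writes and delete the half-written spill file. A generic sketch of this write-or-roll-back idiom, using a stand-in writer trait rather than Spark's real disk writer:

import java.io.File

// Minimal stand-in for a disk writer that can either commit or roll back.
trait SpillWriter {
  def write(record: AnyRef): Unit
  def commitAndClose(): Unit
  def revertPartialWritesAndClose(): Unit
}

object SpillSketch {
  def spill(records: Iterator[AnyRef], file: File, newWriter: () => SpillWriter): Unit = {
    var writer: SpillWriter = newWriter()
    var success = false
    try {
      records.foreach(writer.write)
      val w = writer
      writer = null            // the finally block must not revert a committed writer
      w.commitAndClose()
      success = true
    } finally {
      if (!success) {
        // An exception was thrown above: roll back the open writer, remove the
        // half-written file, and let the exception propagate.
        if (writer != null) {
          writer.revertPartialWritesAndClose()
        }
        if (file.exists()) {
          file.delete()
        }
      }
    }
  }
}
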
@@ -389,27 +404,51 @@ class ExternalAppendOnlyMap[K, V, C](
   * An iterator that returns (K, C) pairs in sorted order from an on-disk map
   */
  private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long])
-    extends Iterator[(K, C)] {
-    private val fileStream = new FileInputStream(file)
-    private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+    extends Iterator[(K, C)]
+  {
+    private val batchOffsets = batchSizes.scanLeft(0L)(_ + _)  // Size will be batchSize.length + 1
+    assert(file.length() == batchOffsets(batchOffsets.length - 1))
+
+    private var batchIndex = 0  // Which batch we're in
+    private var fileStream: FileInputStream = null
 
    // An intermediate stream that reads from exactly one batch
    // This guards against pre-fetching and other arbitrary behavior of higher level streams
-    private var batchStream = nextBatchStream()
-    private var compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-    private var deserializeStream = ser.deserializeStream(compressedStream)
+    private var deserializeStream = nextBatchStream()
    private var nextItem: (K, C) = null
    private var objectsRead = 0
 
    /**
     * Construct a stream that reads only from the next batch.
     */
-    private def nextBatchStream(): InputStream = {
-      if (batchSizes.length > 0) {
-        ByteStreams.limit(bufferedStream, batchSizes.remove(0))
+    private def nextBatchStream(): DeserializationStream = {
+      // Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether
+      // we're still in a valid batch.
+      if (batchIndex < batchOffsets.length - 1) {
+        if (deserializeStream != null) {
+          deserializeStream.close()
+          fileStream.close()
+          deserializeStream = null
+          fileStream = null
+        }
+
+        val start = batchOffsets(batchIndex)
+        fileStream = new FileInputStream(file)
+        fileStream.getChannel.position(start)
+        batchIndex += 1
+
+        val end = batchOffsets(batchIndex)
+
+        assert(end >= start, "start = " + start + ", end = " + end +
+          ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
+
+        val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
+        val compressedStream = blockManager.wrapForCompression(blockId, bufferedStream)
+        ser.deserializeStream(compressedStream)
      } else {
        // No more batches left
-        bufferedStream
+        cleanup()
+        null
      }
    }
 
@@ -424,10 +463,8 @@ class ExternalAppendOnlyMap[K, V, C](
        val item = deserializeStream.readObject().asInstanceOf[(K, C)]
        objectsRead += 1
        if (objectsRead == serializerBatchSize) {
-          batchStream = nextBatchStream()
-          compressedStream = blockManager.wrapForCompression(blockId, batchStream)
-          deserializeStream = ser.deserializeStream(compressedStream)
          objectsRead = 0
+          deserializeStream = nextBatchStream()
        }
        item
      } catch {
@@ -439,6 +476,9 @@ class ExternalAppendOnlyMap[K, V, C](
 
    override def hasNext: Boolean = {
      if (nextItem == null) {
+        if (deserializeStream == null) {
+          return false
+        }
        nextItem = readNextItem()
      }
      nextItem != null
@@ -455,7 +495,11 @@ class ExternalAppendOnlyMap[K, V, C](
 
    // TODO: Ensure this gets called even if the iterator isn't drained.
    private def cleanup() {
-      deserializeStream.close()
+      batchIndex = batchOffsets.length  // Prevent reading any other batch
+      val ds = deserializeStream
+      deserializeStream = null
+      fileStream = null
+      ds.close()
      file.delete()
    }
  }
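
DiskMapIterator previously shared one BufferedInputStream across batches, which let a compression or serialization stream pre-fetch bytes belonging to the next batch. Now it precomputes cumulative offsets with batchSizes.scanLeft(0L)(_ + _) (one trailing element equal to the total file length), and for each batch opens a fresh FileInputStream, seeks its channel to the batch's start, and caps it at end - start bytes with Guava's ByteStreams.limit. A small sketch of just the offset arithmetic and byte-range open; the helper names are illustrative and the on-disk layout is assumed to be back-to-back batches.

import java.io.{BufferedInputStream, File, FileInputStream, InputStream}
import com.google.common.io.ByteStreams

object BatchOffsetsSketch {
  // scanLeft turns per-batch sizes into cumulative start offsets, with one extra
  // trailing element equal to the total length, e.g. [10, 25, 5] -> [0, 10, 35, 40].
  def offsets(batchSizes: Seq[Long]): Seq[Long] = batchSizes.scanLeft(0L)(_ + _)

  // Open a stream that can only see the bytes of batch i.
  def openBatch(file: File, batchOffsets: Seq[Long], i: Int): InputStream = {
    val start = batchOffsets(i)
    val end   = batchOffsets(i + 1)
    require(end >= start, s"corrupt offsets: start=$start, end=$end")
    val fileStream = new FileInputStream(file)
    fileStream.getChannel.position(start)    // seek to the batch start
    new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
  }
}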
