Commit 89964ab
use temp folder for checkpoint
1 parent: 825d4fe

File tree: 2 files changed (+15, -10 lines)


mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala

Lines changed: 4 additions & 4 deletions
@@ -49,7 +49,6 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
 
   // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
   test("parameter accuracy") {
-
     // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0, 0.0))
@@ -78,11 +77,12 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
     val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17)
     validatePrediction(validationData.map(row => model.latestModel().predict(row.features)),
       validationData)
+
+    ssc.stop()
   }
 
   // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
   test("parameter convergence") {
-
     // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0))
@@ -114,11 +114,11 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
     // check that error shrunk on at least 2 batches
     assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1)
 
+    ssc.stop()
   }
 
   // Test predictions on a stream
   test("predictions") {
-
     // create model initialized with true weights
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(10.0, 10.0))
@@ -143,6 +143,6 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
     val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
     assert(errors.forall(x => x <= 0.1))
 
+    ssc.stop()
   }
-
 }
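Each test now calls ssc.stop() before returning, so the StreamingContext (and its underlying SparkContext) is released between tests. As a hedged aside, a common alternative is a try/finally helper that stops the context even when an assertion fails mid-test; a minimal sketch, using a hypothetical helper object that is not part of this commit:

    import org.apache.spark.streaming.StreamingContext

    object StreamingTestUtils {
      // Run a test body and always stop the StreamingContext afterwards,
      // so a failed assertion cannot leak the context into later tests.
      // (Illustrative helper; not part of this commit.)
      def withStreamingContext[T](ssc: StreamingContext)(body: StreamingContext => T): T = {
        try body(ssc)
        finally ssc.stop()  // by default this also stops the underlying SparkContext
      }
    }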

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 11 additions & 6 deletions
@@ -17,18 +17,18 @@
 
 package org.apache.spark.streaming
 
-import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
-import org.apache.spark.streaming.util.ManualClock
+import java.io.{ObjectInputStream, IOException}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.SynchronizedBuffer
 import scala.reflect.ClassTag
 
-import java.io.{ObjectInputStream, IOException}
-
 import org.scalatest.{BeforeAndAfter, FunSuite}
+import com.google.common.io.Files
 
-import org.apache.spark.{SparkContext, SparkConf, Logging}
+import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
+import org.apache.spark.streaming.util.ManualClock
+import org.apache.spark.{SparkConf, Logging}
 import org.apache.spark.rdd.RDD
 
 /**
@@ -119,7 +119,12 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
   def batchDuration = Seconds(1)
 
   // Directory where the checkpoint data will be saved
-  def checkpointDir = "checkpoint"
+  def checkpointDir = {
+    val dir = Files.createTempDir()
+    logDebug(s"checkpointDir: $dir")
+    dir.deleteOnExit()
+    dir.toString
+  }
 
   // Number of partitions of the input parallel collections created for testing
   def numInputPartitions = 2
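The checkpoint directory is now a fresh Guava temp directory per call, logged and marked deleteOnExit(), instead of the fixed relative path "checkpoint". Note that java.io.File.deleteOnExit() only removes a directory that is still empty at JVM shutdown, so checkpoint files written into it can survive the run. A minimal sketch of the same pattern with explicit recursive cleanup; the withCheckpointDir helper is hypothetical and not part of this commit:

    import java.io.File
    import com.google.common.io.Files

    object CheckpointDirUtils {
      // Delete a directory tree bottom-up; File.delete() fails on non-empty dirs.
      private def deleteRecursively(f: File): Unit = {
        if (f.isDirectory) f.listFiles().foreach(deleteRecursively)
        f.delete()
      }

      // Run a test body with a fresh temp checkpoint dir, then remove it
      // even if the body throws. (Illustrative; not part of this commit.)
      def withCheckpointDir[T](body: String => T): T = {
        val dir = Files.createTempDir()  // Guava: unique dir under java.io.tmpdir
        try body(dir.toString)
        finally deleteRecursively(dir)
      }
    }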
