mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -49,7 +49,6 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
 
   // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
   test("parameter accuracy") {
-
     // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0, 0.0))
@@ -82,7 +81,6 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
 
   // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
   test("parameter convergence") {
-
     // create model
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(0.0))
@@ -113,12 +111,10 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
     assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
     // check that error shrunk on at least 2 batches
     assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1)
-
   }
 
   // Test predictions on a stream
   test("predictions") {
-
     // create model initialized with true weights
     val model = new StreamingLinearRegressionWithSGD()
       .setInitialWeights(Vectors.dense(10.0, 10.0))
@@ -142,7 +138,5 @@ class StreamingLinearRegressionSuite extends FunSuite with TestSuiteBase {
     // compute the mean absolute error and check that it's always less than 0.1
     val errors = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints)
     assert(errors.forall(x => x <= 0.1))
-
   }
-
 }
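For orientation, and not part of the diff: the suite above exercises StreamingLinearRegressionWithSGD against DStreams that TestSuiteBase constructs. Below is a minimal, self-contained sketch of the same training loop, assuming the trainOn/latestModel API of StreamingLinearAlgorithm and a queueStream of pre-built batches; the object name and the data generation are illustrative, not taken from the PR.

import scala.collection.mutable
import scala.util.Random

import org.apache.spark.SparkConf
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD}
import org.apache.spark.streaming.{Seconds, StreamingContext}

object StreamingRegressionSketch {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(
      new SparkConf().setMaster("local[2]").setAppName("sketch"), Seconds(1))

    // ten batches of noisy points drawn from Y = 10*X1 + 10*X2, the suite's target
    val rng = new Random(42)
    val batches = mutable.Queue((1 to 10).map { _ =>
      ssc.sparkContext.parallelize((1 to 100).map { _ =>
        val (x1, x2) = (rng.nextGaussian(), rng.nextGaussian())
        LabeledPoint(10 * x1 + 10 * x2 + 0.1 * rng.nextGaussian(), Vectors.dense(x1, x2))
      })
    }: _*)

    val model = new StreamingLinearRegressionWithSGD()
      .setInitialWeights(Vectors.dense(0.0, 0.0))
      .setStepSize(0.1)
      .setNumIterations(50)

    // one gradient-descent update per arriving batch
    model.trainOn(ssc.queueStream(batches))

    ssc.start()
    Thread.sleep(12000) // let the ten one-second batches play out
    println(s"learned weights: ${model.latestModel().weights}") // should approach [10.0, 10.0]
    ssc.stop()
  }
}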
streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -17,18 +17,18 @@
 
 package org.apache.spark.streaming
 
-import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
-import org.apache.spark.streaming.util.ManualClock
+import java.io.{ObjectInputStream, IOException}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.SynchronizedBuffer
 import scala.reflect.ClassTag
 
-import java.io.{ObjectInputStream, IOException}
-
 import org.scalatest.{BeforeAndAfter, FunSuite}
+import com.google.common.io.Files
 
-import org.apache.spark.{SparkContext, SparkConf, Logging}
+import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
+import org.apache.spark.streaming.util.ManualClock
+import org.apache.spark.{SparkConf, Logging}
 import org.apache.spark.rdd.RDD
 
 /**
@@ -119,7 +119,12 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
   def batchDuration = Seconds(1)
 
   // Directory where the checkpoint data will be saved
-  def checkpointDir = "checkpoint"
+  lazy val checkpointDir = {
+    val dir = Files.createTempDir()
+    logDebug(s"checkpointDir: $dir")
+    dir.deleteOnExit()
+    dir.toString
+  }
 
   // Number of partitions of the input parallel collections created for testing
   def numInputPartitions = 2

Review comment (Member, on "val dir = Files.createTempDir()"): Ideally, delete this explicitly and recursively in the 'after' function. It could fail to be deleted if the JVM crashes later on. Minor though.
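The suggestion spelled out, as a hedged sketch rather than code from the PR: the trait name is illustrative, and commons-io's FileUtils.deleteDirectory stands in for any recursive delete (Spark's own Utils.deleteRecursively would do the same job).

import java.io.File

import com.google.common.io.Files
import org.apache.commons.io.FileUtils
import org.scalatest.{BeforeAndAfter, FunSuite}

trait TempCheckpointSuite extends FunSuite with BeforeAndAfter {
  var checkpointDir: String = _

  before {
    checkpointDir = Files.createTempDir().toString // fresh directory per test
  }

  after {
    // delete explicitly and recursively; deleteOnExit() never runs if the JVM crashes
    FileUtils.deleteDirectory(new File(checkpointDir))
  }
}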