[WIP] SPARK-1430: Support sparse data in Python MLlib
This PR adds a SparseVector class in PySpark and updates all the regression, classification and clustering algorithms and models to support sparse data, mirroring what MLlib supports in Scala. I chose to add this class because SciPy is quite difficult to install in many environments (more so than NumPy), but I plan to add support for SciPy sparse vectors later too, and to make the methods work transparently on objects of either type.

On the Scala side, we keep Python sparse vectors sparse and pass them to MLlib. We always return dense vectors from our models.
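
For reference, a minimal sketch of how the new classes are meant to be used (the constructor forms follow the commits listed below, e.g. ab244d1 for dict initialization; the exact signatures live in `pyspark/mllib/linalg.py` and may differ slightly):

```python
from pyspark.mllib.linalg import SparseVector
from pyspark.mllib.regression import LabeledPoint

# A size-4 vector with non-zeros at indices 1 and 3, given as parallel
# index/value lists...
sv = SparseVector(4, [1, 3], [1.0, 5.5])

# ...or as an {index: value} dict (see commit ab244d1 below).
sv2 = SparseVector(4, {1: 1.0, 3: 5.5})

# Labels are attached via LabeledPoint instead of an extra leading column.
point = LabeledPoint(1.0, sv)
```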

Some to-do items left:
- [x] Support SciPy's scipy.sparse matrix objects when SciPy is available. We can easily add a function to convert these to our own SparseVector (a sketch of the SciPy side appears after this list).
- [x] MLlib currently uses a vector with one extra column on the left to represent what we call LabeledPoint in Scala. Do we really want this? It may get annoying once you deal with sparse data since you must add/subtract 1 to each feature index when training. We can remove this API in 1.0 and use tuples for labeling.
- [x] Explain how to use these in the Python MLlib docs.
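
For the first item above, a hedged sketch of the SciPy side only (it assumes SciPy is installed; the conversion into MLlib's own SparseVector is handled by code added in this PR and is not reproduced here):

```python
import numpy as np
from scipy.sparse import csc_matrix

# A single-column CSC matrix standing in for a feature vector of size 4
# with non-zeros at rows 0 and 2: (data, row indices, column pointers).
features = csc_matrix((np.array([1.0, 2.0]), np.array([0, 2]), np.array([0, 2])),
                      shape=(4, 1))

print(features.toarray().ravel())  # [ 1.  0.  2.  0.]
```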

CC @mengxr, @JoshRosen

Author: Matei Zaharia <matei@databricks.com>

Closes apache#341 from mateiz/py-ml-update and squashes the following commits:

d52e763 [Matei Zaharia] Remove no-longer-needed slice code and handle review comments
ea5a25a [Matei Zaharia] Fix remaining uses of copyto() after merge
b9f97a3 [Matei Zaharia] Fix test
1e1bd0f [Matei Zaharia] Add MLlib logistic regression example in Python
88bc01f [Matei Zaharia] Clean up inheritance of LinearModel in Python, and expose its parameters
37ab747 [Matei Zaharia] Fix some examples and docs due to changes in MLlib API
da0f27e [Matei Zaharia] Added a MLlib K-means example and updated docs to discuss sparse data
c48e85a [Matei Zaharia] Added some tests for passing lists as input, and added mllib/tests.py to run-tests script.
a07ba10 [Matei Zaharia] Fix some typos and calculation of initial weights
74eefe7 [Matei Zaharia] Added LabeledPoint class in Python
889dde8 [Matei Zaharia] Support scipy.sparse matrices in all our algorithms and models
ab244d1 [Matei Zaharia] Allow SparseVectors to be initialized using a dict
a5d6426 [Matei Zaharia] Add linalg.py to run-tests script
0e7a3d8 [Matei Zaharia] Keep vectors sparse in Java when reading LabeledPoints
eaee759 [Matei Zaharia] Update regression, classification and clustering models for sparse data
2abbb44 [Matei Zaharia] Further work to get linear models working with sparse data
154f45d [Matei Zaharia] Update docs, name some magic values
881fef7 [Matei Zaharia] Added a sparse vector in Python and made Java-Python format more compact
mateiz authored and pwendell committed Apr 16, 2014
1 parent 8517911 commit 63ca581
Showing 18 changed files with 1,368 additions and 214 deletions.
45 changes: 27 additions & 18 deletions docs/mllib-classification-regression.md
@@ -356,16 +356,17 @@ error.
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.Vectors

// Load and parse the data file
val data = sc.textFile("mllib/data/sample_svm_data.txt")
val parsedData = data.map { line =>
-val parts = line.split(' ')
-LabeledPoint(parts(0).toDouble, parts.tail.map(x => x.toDouble).toArray)
+val parts = line.split(' ').map(_.toDouble)
+LabeledPoint(parts(0), Vectors.dense(parts.tail))
}

// Run training algorithm to build the model
-val numIterations = 20
+val numIterations = 100
val model = SVMWithSGD.train(parsedData, numIterations)

// Evaluate model on training examples and compute training error
@@ -401,29 +402,30 @@ val modelL1 = svmAlg.run(parsedData)
The following example demonstrates how to load training data and parse it as an RDD of LabeledPoint.
The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
values. We compute the Mean Squared Error at the end to evaluate
-[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit)
+[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).

{% highlight scala %}
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.Vectors

// Load and parse the data
val data = sc.textFile("mllib/data/ridge-data/lpsa.data")
val parsedData = data.map { line =>
val parts = line.split(',')
-LabeledPoint(parts(0).toDouble, parts(1).split(' ').map(x => x.toDouble).toArray)
+LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
}

// Building the model
-val numIterations = 20
+val numIterations = 100
val model = LinearRegressionWithSGD.train(parsedData, numIterations)

// Evaluate model on training examples and compute training error
val valuesAndPreds = parsedData.map { point =>
val prediction = model.predict(point.features)
(point.label, prediction)
}
-val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2)}.reduce(_ + _)/valuesAndPreds.count
+val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.reduce(_ + _) / valuesAndPreds.count
println("training Mean Squared Error = " + MSE)
{% endhighlight %}

@@ -518,18 +520,22 @@ and make predictions with the resulting model to compute the training error.

{% highlight python %}
from pyspark.mllib.classification import LogisticRegressionWithSGD
+from pyspark.mllib.regression import LabeledPoint
from numpy import array

# Load and parse the data
+def parsePoint(line):
+    values = [float(x) for x in line.split(' ')]
+    return LabeledPoint(values[0], values[1:])
+
data = sc.textFile("mllib/data/sample_svm_data.txt")
-parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')]))
-model = LogisticRegressionWithSGD.train(parsedData)
+parsedData = data.map(parsePoint)

# Build the model
-labelsAndPreds = parsedData.map(lambda point: (int(point.item(0)),
-        model.predict(point.take(range(1, point.size)))))
+model = LogisticRegressionWithSGD.train(parsedData)

# Evaluating the model on training data
+labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features)))
trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count())
print("Training Error = " + str(trainErr))
{% endhighlight %}
@@ -538,22 +544,25 @@ print("Training Error = " + str(trainErr))
The following example demonstrates how to load training data and parse it as an RDD of LabeledPoint.
The example then uses LinearRegressionWithSGD to build a simple linear model to predict label
values. We compute the Mean Squared Error at the end to evaluate
-[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit)
+[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).

{% highlight python %}
-from pyspark.mllib.regression import LinearRegressionWithSGD
+from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD
from numpy import array

# Load and parse the data
+def parsePoint(line):
+    values = [float(x) for x in line.replace(',', ' ').split(' ')]
+    return LabeledPoint(values[0], values[1:])
+
data = sc.textFile("mllib/data/ridge-data/lpsa.data")
-parsedData = data.map(lambda line: array([float(x) for x in line.replace(',', ' ').split(' ')]))
+parsedData = data.map(parsePoint)

# Build the model
model = LinearRegressionWithSGD.train(parsedData)

# Evaluate the model on training data
-valuesAndPreds = parsedData.map(lambda point: (point.item(0),
-        model.predict(point.take(range(1, point.size)))))
-MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y)/valuesAndPreds.count()
+valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features)))
+MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count()
print("Mean Squared Error = " + str(MSE))
{% endhighlight %}
11 changes: 6 additions & 5 deletions docs/mllib-clustering.md
@@ -48,14 +48,15 @@ optimal *k* is usually one where there is an "elbow" in the WSSSE graph.

{% highlight scala %}
import org.apache.spark.mllib.clustering.KMeans
+import org.apache.spark.mllib.linalg.Vectors

// Load and parse the data
val data = sc.textFile("kmeans_data.txt")
val parsedData = data.map( _.split(' ').map(_.toDouble))
val data = sc.textFile("data/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble)))

// Cluster the data into two classes using KMeans
-val numIterations = 20
val numClusters = 2
+val numIterations = 20
val clusters = KMeans.train(parsedData, numClusters, numIterations)

// Evaluate clustering by computing Within Set Sum of Squared Errors
@@ -85,12 +86,12 @@ from numpy import array
from math import sqrt

# Load and parse the data
data = sc.textFile("kmeans_data.txt")
data = sc.textFile("data/kmeans_data.txt")
parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')]))

# Build the model (cluster the data)
clusters = KMeans.train(parsedData, 2, maxIterations=10,
-        runs=30, initialization_mode="random")
+        runs=10, initialization_mode="random")

# Evaluate clustering by computing Within Set Sum of Squared Errors
def error(point):
27 changes: 25 additions & 2 deletions docs/mllib-guide.md
@@ -7,8 +7,9 @@ title: Machine Learning Library (MLlib)
MLlib is a Spark implementation of some common machine learning (ML)
functionality, as well as associated tests and data generators. MLlib
currently supports four common types of machine learning problem settings,
-namely, binary classification, regression, clustering and collaborative
-filtering, as well as an underlying gradient descent optimization primitive.
+namely classification, regression, clustering and collaborative filtering,
+as well as an underlying gradient descent optimization primitive and several
+linear algebra methods.

# Available Methods
The following links provide a detailed explanation of the methods and usage examples for each of them:
@@ -32,6 +33,28 @@ The following links provide a detailed explanation of the methods and usage examples
* Singular Value Decomposition
* Principal Component Analysis

# Data Types

Most MLlib algorithms operate on RDDs containing vectors. In Java and Scala, the
[Vector](api/mllib/index.html#org.apache.spark.mllib.linalg.Vector) class is used to
represent vectors. You can create either dense or sparse vectors using the
[Vectors](api/mllib/index.html#org.apache.spark.mllib.linalg.Vectors$) factory.

In Python, MLlib can take the following vector types:

* [NumPy](http://www.numpy.org) arrays
* Standard Python lists (e.g. `[1, 2, 3]`)
* The MLlib [SparseVector](api/pyspark/pyspark.mllib.linalg.SparseVector-class.html) class
* [SciPy sparse matrices](http://docs.scipy.org/doc/scipy/reference/sparse.html)

For efficiency, we recommend using NumPy arrays over lists, and using the
[CSC format](http://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html#scipy.sparse.csc_matrix)
for SciPy matrices, or MLlib's own SparseVector class.

Several other simple data types are used throughout the library, e.g. the LabeledPoint
class ([Java/Scala](api/mllib/index.html#org.apache.spark.mllib.regression.LabeledPoint),
[Python](api/pyspark/pyspark.mllib.regression.LabeledPoint-class.html)) for labeled data.
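
As an illustration only (not part of this commit's diff), a short sketch of combining the vector types listed above in one training set; it assumes an existing `SparkContext` named `sc`, as in the examples elsewhere in these docs, and it assumes that a single RDD may mix dense and sparse feature vectors:

{% highlight python %}
import numpy as np
from pyspark.mllib.linalg import SparseVector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.classification import LogisticRegressionWithSGD

data = [
    LabeledPoint(0.0, np.array([0.0, 1.0, 0.0, 7.0])),       # dense NumPy features
    LabeledPoint(1.0, SparseVector(4, [1, 3], [1.0, 5.5])),  # sparse MLlib features
]
model = LogisticRegressionWithSGD.train(sc.parallelize(data))
{% endhighlight %}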

# Dependencies
MLlib uses the [jblas](https://github.com/mikiobraun/jblas) linear algebra library, which itself
depends on native Fortran routines. You may need to install the