Skip to content

Commit 686dd74

Browse files
committed
[SPARK-7036][MLLIB] ALS.train should support DataFrames in PySpark
SchemaRDD works with ALS.train in 1.2, so we should continue supporting DataFrames for compatibility. coderxiang Author: Xiangrui Meng <meng@databricks.com> Closes #5619 from mengxr/SPARK-7036 and squashes the following commits: dfcaf5a [Xiangrui Meng] ALS.train should support DataFrames in PySpark
1 parent 7fe6142 commit 686dd74

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

python/pyspark/mllib/recommendation.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pyspark.rdd import RDD
2323
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
2424
from pyspark.mllib.util import JavaLoader, JavaSaveable
25+
from pyspark.sql import DataFrame
2526

2627
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
2728

@@ -78,18 +79,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
7879
True
7980
8081
>>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
81-
>>> model.predict(2,2)
82+
>>> model.predict(2, 2)
83+
3.8...
84+
85+
>>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
86+
>>> model = ALS.train(df, 1, nonnegative=True, seed=10)
87+
>>> model.predict(2, 2)
8288
3.8...
8389
8490
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
85-
>>> model.predict(2,2)
91+
>>> model.predict(2, 2)
8692
0.4...
8793
8894
>>> import os, tempfile
8995
>>> path = tempfile.mkdtemp()
9096
>>> model.save(sc, path)
9197
>>> sameModel = MatrixFactorizationModel.load(sc, path)
92-
>>> sameModel.predict(2,2)
98+
>>> sameModel.predict(2, 2)
9399
0.4...
94100
>>> sameModel.predictAll(testset).collect()
95101
[Rating(...
@@ -125,13 +131,20 @@ class ALS(object):
125131

126132
@classmethod
def _prepare(cls, ratings):
    """Normalize user-supplied ratings into an RDD of Rating objects.

    Accepts either an RDD or a DataFrame; a DataFrame is converted to
    its underlying RDD. If the elements are plain tuples/lists they are
    mapped to Rating instances; elements that are already Rating pass
    through unchanged.

    :param ratings: an RDD or DataFrame of ratings
    :return: an RDD whose elements are Rating objects
    :raise TypeError: if `ratings` is neither an RDD nor a DataFrame,
        or if its first element is not a Rating or a tuple/list
    """
    if isinstance(ratings, DataFrame):
        # DataFrame rows unpack like tuples, so the Rating(*x) mapping
        # below works on the underlying RDD as well.
        ratings = ratings.rdd
    elif not isinstance(ratings, RDD):
        raise TypeError("Ratings should be represented by either an RDD or a DataFrame, "
                        "but got %s." % type(ratings))
    # Inspect only the first element; the RDD is assumed homogeneous.
    first = ratings.first()
    if not isinstance(first, Rating):
        if isinstance(first, (tuple, list)):
            ratings = ratings.map(lambda x: Rating(*x))
        else:
            raise TypeError("Expect a Rating or a tuple/list, but got %s." % type(first))
    return ratings
136149

137150
@classmethod
@@ -152,8 +165,11 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp
152165
def _test():
153166
import doctest
154167
import pyspark.mllib.recommendation
168+
from pyspark.sql import SQLContext
155169
globs = pyspark.mllib.recommendation.__dict__.copy()
156-
globs['sc'] = SparkContext('local[4]', 'PythonTest')
170+
sc = SparkContext('local[4]', 'PythonTest')
171+
globs['sc'] = sc
172+
globs['sqlContext'] = SQLContext(sc)
157173
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
158174
globs['sc'].stop()
159175
if failure_count:

0 commit comments

Comments
 (0)