22
22
from pyspark .rdd import RDD
23
23
from pyspark .mllib .common import JavaModelWrapper , callMLlibFunc , inherit_doc
24
24
from pyspark .mllib .util import JavaLoader , JavaSaveable
25
+ from pyspark .sql import DataFrame
25
26
26
27
__all__ = ['MatrixFactorizationModel' , 'ALS' , 'Rating' ]
27
28
@@ -78,18 +79,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
78
79
True
79
80
80
81
>>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
81
- >>> model.predict(2,2)
82
+ >>> model.predict(2, 2)
83
+ 3.8...
84
+
85
+ >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
86
+ >>> model = ALS.train(df, 1, nonnegative=True, seed=10)
87
+ >>> model.predict(2, 2)
82
88
3.8...
83
89
84
90
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
85
- >>> model.predict(2,2)
91
+ >>> model.predict(2, 2)
86
92
0.4...
87
93
88
94
>>> import os, tempfile
89
95
>>> path = tempfile.mkdtemp()
90
96
>>> model.save(sc, path)
91
97
>>> sameModel = MatrixFactorizationModel.load(sc, path)
92
- >>> sameModel.predict(2,2)
98
+ >>> sameModel.predict(2, 2)
93
99
0.4...
94
100
>>> sameModel.predictAll(testset).collect()
95
101
[Rating(...
@@ -125,13 +131,20 @@ class ALS(object):
125
131
126
132
@classmethod
127
133
def _prepare (cls , ratings ):
128
- assert isinstance (ratings , RDD ), "ratings should be RDD"
134
+ if isinstance (ratings , RDD ):
135
+ pass
136
+ elif isinstance (ratings , DataFrame ):
137
+ ratings = ratings .rdd
138
+ else :
139
+ raise TypeError ("Ratings should be represented by either an RDD or a DataFrame, "
140
+ "but got %s." % type (ratings ))
129
141
first = ratings .first ()
130
- if not isinstance (first , Rating ):
131
- if isinstance (first , (tuple , list )):
132
- ratings = ratings .map (lambda x : Rating (* x ))
133
- else :
134
- raise ValueError ("rating should be RDD of Rating or tuple/list" )
142
+ if isinstance (first , Rating ):
143
+ pass
144
+ elif isinstance (first , (tuple , list )):
145
+ ratings = ratings .map (lambda x : Rating (* x ))
146
+ else :
147
+ raise TypeError ("Expect a Rating or a tuple/list, but got %s." % type (first ))
135
148
return ratings
136
149
137
150
@classmethod
@@ -152,8 +165,11 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp
152
165
def _test ():
153
166
import doctest
154
167
import pyspark .mllib .recommendation
168
+ from pyspark .sql import SQLContext
155
169
globs = pyspark .mllib .recommendation .__dict__ .copy ()
156
- globs ['sc' ] = SparkContext ('local[4]' , 'PythonTest' )
170
+ sc = SparkContext ('local[4]' , 'PythonTest' )
171
+ globs ['sc' ] = sc
172
+ globs ['sqlContext' ] = SQLContext (sc )
157
173
(failure_count , test_count ) = doctest .testmod (globs = globs , optionflags = doctest .ELLIPSIS )
158
174
globs ['sc' ].stop ()
159
175
if failure_count :
0 commit comments