@@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
6363 indicated user preferences rather than explicit ratings given to
6464 items.
6565
66+ >>> df = sqlContext.createDataFrame(
67+ ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
68+ ... ["user", "item", "rating"])
6669 >>> als = ALS(rank=10, maxIter=5)
6770 >>> model = als.fit(df)
71+ >>> model.rank
72+ 10
73+ >>> model.userFactors.orderBy("id").collect()
74+ [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
6875 >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
6976 >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
7077 >>> predictions[0]
@@ -260,6 +267,27 @@ class ALSModel(JavaModel):
260267 Model fitted by ALS.
261268 """
262269
270+ @property
271+ def rank (self ):
272+ """rank of the matrix factorization model"""
273+ return self ._call_java ("rank" )
274+
275+ @property
276+ def userFactors (self ):
277+ """
278+ a DataFrame that stores user factors in two columns: `id` and
279+ `features`
280+ """
281+ return self ._call_java ("userFactors" )
282+
283+ @property
284+ def itemFactors (self ):
285+ """
286+ a DataFrame that stores item factors in two columns: `id` and
287+ `features`
288+ """
289+ return self ._call_java ("itemFactors" )
290+
263291
264292if __name__ == "__main__" :
265293 import doctest
@@ -272,8 +300,6 @@ class ALSModel(JavaModel):
272300 sqlContext = SQLContext (sc )
273301 globs ['sc' ] = sc
274302 globs ['sqlContext' ] = sqlContext
275- globs ['df' ] = sqlContext .createDataFrame ([(0 , 0 , 4.0 ), (0 , 1 , 2.0 ), (1 , 1 , 3.0 ), (1 , 2 , 4.0 ),
276- (2 , 1 , 1.0 ), (2 , 2 , 5.0 )], ["user" , "item" , "rating" ])
277303 (failure_count , test_count ) = doctest .testmod (globs = globs , optionflags = doctest .ELLIPSIS )
278304 sc .stop ()
279305 if failure_count :
0 commit comments