Skip to content

Commit 7bfb1d5

Browse files
committed
update ALSModel in PySpark
1 parent 1ba5607 commit 7bfb1d5

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

python/pyspark/ml/recommendation.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
6363
indicated user preferences rather than explicit ratings given to
6464
items.
6565
66+
>>> df = sqlContext.createDataFrame(
67+
... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
68+
... ["user", "item", "rating"])
6669
>>> als = ALS(rank=10, maxIter=5)
6770
>>> model = als.fit(df)
71+
>>> model.rank
72+
10
73+
>>> model.userFactors.orderBy("id").collect()
74+
[Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
6875
>>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
6976
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
7077
>>> predictions[0]
@@ -260,6 +267,27 @@ class ALSModel(JavaModel):
260267
Model fitted by ALS.
261268
"""
262269

270+
@property
def rank(self):
    """Rank of the matrix factorization model (number of latent features)."""
    # Delegate to the underlying JVM ALSModel via the Py4J bridge.
    java_rank = self._call_java("rank")
    return java_rank
274+
275+
@property
def userFactors(self):
    """
    A DataFrame that stores user factors in two columns: `id` and
    `features`.
    """
    # The JVM side returns a Java DataFrame; _call_java wraps it into
    # a Python DataFrame before handing it back.
    factors_df = self._call_java("userFactors")
    return factors_df
282+
283+
@property
def itemFactors(self):
    """
    A DataFrame that stores item factors in two columns: `id` and
    `features`.
    """
    # Mirrors userFactors: fetch the item-factor DataFrame from the
    # wrapped JVM model through the Py4J bridge.
    factors_df = self._call_java("itemFactors")
    return factors_df
290+
263291

264292
if __name__ == "__main__":
265293
import doctest
@@ -272,8 +300,6 @@ class ALSModel(JavaModel):
272300
sqlContext = SQLContext(sc)
273301
globs['sc'] = sc
274302
globs['sqlContext'] = sqlContext
275-
globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0),
276-
(2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"])
277303
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
278304
sc.stop()
279305
if failure_count:

python/pyspark/mllib/common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from pyspark import RDD, SparkContext
2929
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
30-
30+
from pyspark.sql import DataFrame, SQLContext
3131

3232
# Hack for support float('inf') in Py4j
3333
_old_smart_decode = py4j.protocol.smart_decode
@@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"):
9999
jrdd = sc._jvm.SerDe.javaToPython(r)
100100
return RDD(jrdd, sc)
101101

102+
if clsName == 'DataFrame':
103+
return DataFrame(r, SQLContext(sc))
104+
102105
if clsName in _picklable_classes:
103106
r = sc._jvm.SerDe.dumps(r)
104107
elif isinstance(r, (JavaArray, JavaList)):

0 commit comments

Comments (0)