[SPARK-19281][PYTHON][ML] spark.ml Python API for FPGrowth #17218
Changes from all commits
python/pyspark/ml/fpm.py
@@ -0,0 +1,216 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark import keyword_only, since
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *

__all__ = ["FPGrowth", "FPGrowthModel"]


class HasSupport(Params):
    """
    Mixin for param minSupport.
    """

    minSupport = Param(
        Params._dummy(),
        "minSupport",
        """Minimal support level of the frequent pattern. [0.0, 1.0].
        Any pattern that appears more than (minSupport * size-of-the-dataset)
        times will be output""",
        typeConverter=TypeConverters.toFloat)

    def setMinSupport(self, value):
        """
        Sets the value of :py:attr:`minSupport`.
        """
        return self._set(minSupport=value)

    def getMinSupport(self):
        """
        Gets the value of minSupport or its default value.
        """
        return self.getOrDefault(self.minSupport)


class HasConfidence(Params):
    """
    Mixin for param minConfidence.
    """

    minConfidence = Param(
        Params._dummy(),
        "minConfidence",
        """Minimal confidence for generating Association Rule. [0.0, 1.0].
        Note that minConfidence has no effect during fitting.""",
        typeConverter=TypeConverters.toFloat)

    def setMinConfidence(self, value):
        """
        Sets the value of :py:attr:`minConfidence`.
        """
        return self._set(minConfidence=value)

    def getMinConfidence(self):
        """
        Gets the value of minConfidence or its default value.
        """
        return self.getOrDefault(self.minConfidence)


class HasItemsCol(Params):
    """
    Mixin for param itemsCol: items column name.
    """

    itemsCol = Param(Params._dummy(), "itemsCol",
                     "items column name", typeConverter=TypeConverters.toString)

    def setItemsCol(self, value):
        """
        Sets the value of :py:attr:`itemsCol`.
        """
        return self._set(itemsCol=value)

    def getItemsCol(self):
        """
        Gets the value of itemsCol or its default value.
        """
        return self.getOrDefault(self.itemsCol)


class FPGrowthModel(JavaModel, JavaMLWritable, JavaMLReadable):

Review comments on this line:
- Mark Experimental
- Also, it'd be good to be able to set minConfidence, itemsCol and predictionCol (for associationRules and transform).
- I pushed my first attempt, but I think it will require a bit more discussion. If we enable this here, should we do the same for the rest of the Python models? (A sketch of this idea appears after the file diff below.)
""" | ||
.. note:: Experimental | ||
|
||
Model fitted by FPGrowth. | ||
|
||
.. versionadded:: 2.2.0 | ||
""" | ||
@property | ||
@since("2.2.0") | ||
def freqItemsets(self): | ||
""" | ||
DataFrame with two columns: | ||
* `items` - Itemset of the same type as the input column. | ||
* `freq` - Frequency of the itemset (`LongType`). | ||
""" | ||
return self._call_java("freqItemsets") | ||
|
||
@property | ||
@since("2.2.0") | ||
def associationRules(self): | ||
""" | ||
Data with three columns: | ||
* `antecedent` - Array of the same type as the input column. | ||
* `consequent` - Array of the same type as the input column. | ||
* `confidence` - Confidence for the rule (`DoubleType`). | ||
""" | ||
return self._call_java("associationRules") | ||
|
||
|
||
class FPGrowth(JavaEstimator, HasItemsCol, HasPredictionCol, | ||
HasSupport, HasConfidence, JavaMLWritable, JavaMLReadable): | ||
""" | ||
.. note:: Experimental | ||
|
||
A parallel FP-growth algorithm to mine frequent itemsets. The algorithm is described in | ||
Li et al., PFP: Parallel FP-Growth for Query Recommendation [LI2008]_. | ||
PFP distributes computation in such a way that each worker executes an | ||
independent group of mining tasks. The FP-Growth algorithm is described in | ||
Han et al., Mining frequent patterns without candidate generation [HAN2000]_ | ||
|
||
.. [LI2008] http://dx.doi.org/10.1145/1454008.1454027 | ||
.. [HAN2000] http://dx.doi.org/10.1145/335191.335372 | ||
|
||
.. note:: null values in the feature column are ignored during fit(). | ||
.. note:: Internally `transform` `collects` and `broadcasts` association rules. | ||
|
||
>>> from pyspark.sql.functions import split | ||
>>> data = (spark.read | ||
... .text("data/mllib/sample_fpgrowth.txt") | ||
... .select(split("value", "\s+").alias("items"))) | ||
>>> data.show(truncate=False) | ||
+------------------------+ | ||
|items | | ||
+------------------------+ | ||
|[r, z, h, k, p] | | ||
|[z, y, x, w, v, u, t, s]| | ||
|[s, x, o, n, r] | | ||
|[x, z, y, m, t, s, q, e]| | ||
|[z] | | ||
|[x, z, y, r, q, t, p] | | ||
+------------------------+ | ||
>>> fp = FPGrowth(minSupport=0.2, minConfidence=0.7) | ||
>>> fpm = fp.fit(data) | ||
>>> fpm.freqItemsets.show(5) | ||
+---------+----+ | ||
| items|freq| | ||
+---------+----+ | ||
| [s]| 3| | ||
| [s, x]| 3| | ||
|[s, x, z]| 2| | ||
| [s, z]| 2| | ||
| [r]| 3| | ||
+---------+----+ | ||
only showing top 5 rows | ||
>>> fpm.associationRules.show(5) | ||
+----------+----------+----------+ | ||
|antecedent|consequent|confidence| | ||
+----------+----------+----------+ | ||
| [t, s]| [y]| 1.0| | ||
| [t, s]| [x]| 1.0| | ||
| [t, s]| [z]| 1.0| | ||
| [p]| [r]| 1.0| | ||
| [p]| [z]| 1.0| | ||
+----------+----------+----------+ | ||
only showing top 5 rows | ||
>>> new_data = spark.createDataFrame([(["t", "s"], )], ["items"]) | ||
>>> sorted(fpm.transform(new_data).first().prediction) | ||
['x', 'y', 'z'] | ||
|
||
.. versionadded:: 2.2.0 | ||
""" | ||
@keyword_only | ||
def __init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", | ||
predictionCol="prediction", numPartitions=None): | ||
""" | ||
__init__(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ | ||
predictionCol="prediction", numPartitions=None) | ||
""" | ||
super(FPGrowth, self).__init__() | ||
self._java_obj = self._new_java_obj("org.apache.spark.ml.fpm.FPGrowth", self.uid) | ||
self._setDefault(minSupport=0.3, minConfidence=0.8, | ||
itemsCol="items", predictionCol="prediction") | ||
kwargs = self._input_kwargs | ||
self.setParams(**kwargs) | ||
|
||
@keyword_only | ||
@since("2.2.0") | ||
def setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", | ||
predictionCol="prediction", numPartitions=None): | ||
""" | ||
setParams(self, minSupport=0.3, minConfidence=0.8, itemsCol="items", \ | ||
predictionCol="prediction", numPartitions=None) | ||
""" | ||
kwargs = self._input_kwargs | ||
return self._set(**kwargs) | ||
|
||
def _create_model(self, java_model): | ||
return FPGrowthModel(java_model) |
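
A note on the second docstring note above: `transform` collects the fitted model's association rules to the driver, broadcasts them, and predicts per row by firing every rule whose antecedent is contained in that row's itemset. A minimal pure-Python sketch of that semantics (illustration only; the actual work happens on the JVM):

# Pure-Python sketch of FPGrowthModel.transform's prediction semantics.
# Illustration only: the real implementation runs in Scala over broadcast rules.
def predict(items, rules):
    """items: one transaction; rules: (antecedent, consequent, confidence) triples."""
    items = set(items)
    prediction = set()
    for antecedent, consequent, _confidence in rules:
        # A rule fires when its whole antecedent appears in the transaction;
        # consequent items not already present become predictions.
        if set(antecedent) <= items:
            prediction |= set(consequent) - items
    return sorted(prediction)

rules = [(["t", "s"], ["y"], 1.0), (["t", "s"], ["x"], 1.0), (["t", "s"], ["z"], 1.0)]
print(predict(["t", "s"], rules))  # ['x', 'y', 'z'], matching the doctest above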
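
On the review thread attached to the FPGrowthModel definition: a hypothetical sketch of what model-side setters could look like if the model reused the same param mixins. This is not what the diff implements; pushing Python-side param values back to the JVM model is exactly the part that "will require a bit more discussion":

# Hypothetical variant (not part of this PR): FPGrowthModel inheriting the
# param mixins so minConfidence, itemsCol and predictionCol can be changed
# after fit() and picked up by associationRules and transform.
class FPGrowthModel(JavaModel, HasItemsCol, HasPredictionCol, HasConfidence,
                    JavaMLWritable, JavaMLReadable):

    @property
    def associationRules(self):
        # Open question from the thread: Python-side params would have to be
        # synced to the underlying Java model for setMinConfidence() to matter.
        self._transfer_params_to_java()
        return self._call_java("associationRules")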
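
Since both classes mix in JavaMLWritable and JavaMLReadable, a fitted model can be persisted and reloaded. A minimal usage sketch (assumes an active SparkSession bound to `spark`; the path is a placeholder):

# Minimal save/load round trip; "/tmp/fpm" is a placeholder path.
from pyspark.ml.fpm import FPGrowth, FPGrowthModel

df = spark.createDataFrame(
    [([1, 2],), ([1, 2],), ([1, 2, 3],), ([1, 3],)], ["items"])
model = FPGrowth(minSupport=0.5, minConfidence=0.8).fit(df)

model.write().overwrite().save("/tmp/fpm")
reloaded = FPGrowthModel.load("/tmp/fpm")
reloaded.freqItemsets.show()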
python/pyspark/ml/tests.py
@@ -42,26 +42,28 @@
 import array as pyarray
 import numpy as np
 from numpy import (
-    array, array_equal, zeros, inf, random, exp, dot, all, mean, abs, arange, tile, ones)
+    abs, all, arange, array, array_equal, dot, exp, inf, mean, ones, random, tile, zeros)
 from numpy import sum as array_sum
 import inspect

 from pyspark import keyword_only, SparkContext
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
 from pyspark.ml.classification import *
 from pyspark.ml.clustering import *
+from pyspark.ml.common import _java2py, _py2java
 from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
 from pyspark.ml.feature import *
-from pyspark.ml.linalg import Vector, SparseVector, DenseVector, VectorUDT,\
-    DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT, _convert_to_vector
+from pyspark.ml.fpm import FPGrowth, FPGrowthModel
+from pyspark.ml.linalg import (
+    DenseMatrix, DenseVector, Matrices, MatrixUDT,
+    SparseMatrix, SparseVector, Vector, VectorUDT, Vectors, _convert_to_vector)
 from pyspark.ml.param import Param, Params, TypeConverters
-from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
+from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed
 from pyspark.ml.recommendation import ALS
-from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
-    GeneralizedLinearRegression
+from pyspark.ml.regression import (
+    DecisionTreeRegressor, GeneralizedLinearRegression, LinearRegression)
 from pyspark.ml.tuning import *
 from pyspark.ml.wrapper import JavaParams, JavaWrapper
-from pyspark.ml.common import _java2py, _py2java
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand
@@ -1243,6 +1245,43 @@ def test_tweedie_distribution(self):
        self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))


class FPGrowthTests(SparkSessionTestCase):
    def setUp(self):
        super(FPGrowthTests, self).setUp()
        self.data = self.spark.createDataFrame(
            [([1, 2], ), ([1, 2], ), ([1, 2, 3], ), ([1, 3], )],
            ["items"])

    def test_association_rules(self):
        fp = FPGrowth()
        fpm = fp.fit(self.data)

        expected_association_rules = self.spark.createDataFrame(
            [([3], [1], 1.0), ([2], [1], 1.0)],
            ["antecedent", "consequent", "confidence"]
        )
        actual_association_rules = fpm.associationRules
Review comments on this hunk:
- Try inserting [...]
- Maybe it's from not calling the parent setUp and tearDown.
- I don't think so. I reported this before on the developers list (http://apache-spark-developers-list.1001551.n3.nabble.com/ML-PYTHON-Collecting-data-in-a-class-extending-SparkSessionTestCase-causes-AttributeError-td21120.html) with a minimal example. There is something ugly going on here, but it doesn't seem to be related to [...]
- Also, on a clean build I get a lot of exceptions [...] when running [...]
- @davies is looking into this.

(A sketch of the setUp/tearDown variant suggested above follows this diff.)
        self.assertEqual(actual_association_rules.subtract(expected_association_rules).count(), 0)
        self.assertEqual(expected_association_rules.subtract(actual_association_rules).count(), 0)

    def test_freq_itemsets(self):
        fp = FPGrowth()
        fpm = fp.fit(self.data)

        expected_freq_itemsets = self.spark.createDataFrame(
            [([1], 4), ([2], 3), ([2, 1], 3), ([3], 2), ([3, 1], 2)],
            ["items", "freq"]
        )
        actual_freq_itemsets = fpm.freqItemsets

        self.assertEqual(actual_freq_itemsets.subtract(expected_freq_itemsets).count(), 0)
        self.assertEqual(expected_freq_itemsets.subtract(actual_freq_itemsets).count(), 0)

    def tearDown(self):
        del self.data


class ALSTest(SparkSessionTestCase):

    def test_storage_levels(self):
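
On the test-failure thread above: a sketch of the variant one reviewer suggested, chaining tearDown to the parent class as well. Per the follow-up comments this did not turn out to be the cause, so it is illustration only:

# Reviewer-suggested variant (illustration only): chain both setUp and
# tearDown to SparkSessionTestCase. The thread concluded the AttributeError
# was unrelated to this.
class FPGrowthTests(SparkSessionTestCase):
    def setUp(self):
        super(FPGrowthTests, self).setUp()
        self.data = self.spark.createDataFrame(
            [([1, 2],), ([1, 2],), ([1, 2, 3],), ([1, 3],)], ["items"])

    def tearDown(self):
        del self.data
        super(FPGrowthTests, self).tearDown()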
Review comments:
- As long as you're at it, switch tuning & tests to alphabetize them.
- Sure thing. I thought there was some logic in putting tests last. Should I reorder the other modules as well?
- Interesting... maybe? I guess it doesn't really matter, so no need to rearrange more.