Skip to content

Commit 036ca04

Browse files
committed
gen numFeatures
1 parent 46fa147 commit 036ca04

File tree

4 files changed

+35
-17
lines changed

4 files changed

+35
-17
lines changed

python/pyspark/ml/feature.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818
from pyspark.sql import inherit_doc
1919
from pyspark.ml import JavaTransformer
20-
from pyspark.ml.param import Param
21-
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
20+
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
2221

2322

2423
@inherit_doc
@@ -33,23 +32,11 @@ def _java_class(self):
3332

3433

3534
@inherit_doc
36-
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol):
35+
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
3736

3837
def __init__(self):
3938
super(HashingTF, self).__init__()
40-
#: param for number of features
41-
self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
4239

4340
@property
4441
def _java_class(self):
4542
return "org.apache.spark.ml.feature.HashingTF"
46-
47-
def setNumFeatures(self, value):
48-
self.paramMap[self.numFeatures] = value
49-
return self
50-
51-
def getNumFeatures(self):
52-
if self.numFeatures in self.paramMap:
53-
return self.paramMap[self.numFeatures]
54-
else:
55-
return self.numFeatures.defaultValue

python/pyspark/ml/param/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, parent, name, doc, defaultValue=None):
3636
self.defaultValue = defaultValue
3737

3838
def __str__(self):
39-
return str(self.parent) + "_" + self.name
39+
return str(self.parent) + "-" + self.name
4040

4141
def __repr__(self):
4242
return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \

python/pyspark/ml/param/_gen_shared_params.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def get$Name(self):
9090
("labelCol", "label column name", "'label'"),
9191
("predictionCol", "prediction column name", "'prediction'"),
9292
("inputCol", "input column name", "'input'"),
93-
("outputCol", "output column name", "'output'")]
93+
("outputCol", "output column name", "'output'"),
94+
("numFeatures", "number of features", "1 << 18")]
9495
code = []
9596
for name, doc, defaultValue in shared:
9697
code.append(_gen_param_code(name, doc, defaultValue))

python/pyspark/ml/param/shared.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,33 @@ def getOutputCol(self):
228228
return self.paramMap[self.outputCol]
229229
else:
230230
return self.outputCol.defaultValue
231+
232+
233+
class HasNumFeatures(Params):
234+
"""
235+
Params with numFeatures.
236+
"""
237+
238+
# a placeholder to make it appear in the generated doc
239+
numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18)
240+
241+
def __init__(self):
242+
super(HasNumFeatures, self).__init__()
243+
#: param for number of features
244+
self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
245+
246+
def setNumFeatures(self, value):
247+
"""
248+
Sets the value of :py:attr:`numFeatures`.
249+
"""
250+
self.paramMap[self.numFeatures] = value
251+
return self
252+
253+
def getNumFeatures(self):
254+
"""
255+
Gets the value of numFeatures or its default value.
256+
"""
257+
if self.numFeatures in self.paramMap:
258+
return self.paramMap[self.numFeatures]
259+
else:
260+
return self.numFeatures.defaultValue

0 commit comments

Comments
 (0)