Skip to content

Commit cf59634

Browse files
committed
add pyspark rformula save/load
1 parent 1614485 commit cf59634

File tree

1 file changed

+27
-3
lines changed

1 file changed

+27
-3
lines changed

python/pyspark/ml/feature.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2360,7 +2360,7 @@ def explainedVariance(self):
23602360

23612361

23622362
@inherit_doc
2363-
class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
2363+
class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable):
23642364
"""
23652365
.. note:: Experimental
23662366
@@ -2385,7 +2385,31 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
23852385
|0.0|0.0| a|[0.0,1.0]| 0.0|
23862386
+---+---+---+---------+-----+
23872387
...
2388-
>>> rf.fit(df, {rf.formula: "y ~ . - s"}).transform(df).show()
2388+
>>> model = rf.fit(df, {rf.formula: "y ~ . - s"})
2389+
>>> model.transform(df).show()
2390+
+---+---+---+--------+-----+
2391+
| y| x| s|features|label|
2392+
+---+---+---+--------+-----+
2393+
|1.0|1.0| a| [1.0]| 1.0|
2394+
|0.0|2.0| b| [2.0]| 0.0|
2395+
|0.0|0.0| a| [0.0]| 0.0|
2396+
+---+---+---+--------+-----+
2397+
...
2398+
>>> rFormulaPath = temp_path + "/rFormula"
2399+
>>> rf.save(rFormulaPath)
2400+
>>> loadedRF = RFormula.load(rFormulaPath)
2401+
>>> loadedRF.getFormula() == rf.getFormula()
2402+
True
2403+
>>> loadedRF.getFeaturesCol() == rf.getFeaturesCol()
2404+
True
2405+
>>> loadedRF.getLabelCol() == rf.getLabelCol()
2406+
True
2407+
>>> modelPath = temp_path + "/rFormula-model"
2408+
>>> model.save(modelPath)
2409+
>>> loadedModel = RFormulaModel.load(modelPath)
2410+
>>> loadedModel.uid == model.uid
2411+
True
2412+
>>> loadedModel.transform(df).show()
23892413
+---+---+---+--------+-----+
23902414
| y| x| s|features|label|
23912415
+---+---+---+--------+-----+
@@ -2439,7 +2463,7 @@ def _create_model(self, java_model):
24392463
return RFormulaModel(java_model)
24402464

24412465

2442-
class RFormulaModel(JavaModel):
2466+
class RFormulaModel(JavaModel, MLReadable, MLWritable):
24432467
"""
24442468
.. note:: Experimental
24452469

0 commit comments

Comments
 (0)