Skip to content

Commit 133b1b6

Browse files
committed
[SPARK-51473][ML][CONNECT] ML transformed dataframe keep a reference to the model
### What changes were proposed in this pull request? Add the model link in the transformed dataframe. ### Why are the changes needed? apache#49948 disabled the model GC for `fit_transform`; this PR adds the model link in the transformed dataframe, so that the model will be GCed together with the transformed dataframe. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests should cover this change. ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#50199 from zhengruifeng/ml_connect_model_ref. Authored-by: Ruifeng Zheng <ruifengz@apache.org> Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent cfe1b39 commit 133b1b6

File tree

5 files changed

+216
-57
lines changed

5 files changed

+216
-57
lines changed

python/pyspark/ml/classification.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -889,7 +889,10 @@ def summary(self) -> "LinearSVCTrainingSummary": # type: ignore[override]
889889
trained on the training set. An exception is thrown if `trainingSummary is None`.
890890
"""
891891
if self.hasSummary:
892-
return LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
892+
s = LinearSVCTrainingSummary(super(LinearSVCModel, self).summary)
893+
if is_remote():
894+
s.__source_transformer__ = self # type: ignore[attr-defined]
895+
return s
893896
else:
894897
raise RuntimeError(
895898
"No training summary available for this %s" % self.__class__.__name__
@@ -909,7 +912,10 @@ def evaluate(self, dataset: DataFrame) -> "LinearSVCSummary":
909912
if not isinstance(dataset, DataFrame):
910913
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
911914
java_lsvc_summary = self._call_java("evaluate", dataset)
912-
return LinearSVCSummary(java_lsvc_summary)
915+
s = LinearSVCSummary(java_lsvc_summary)
916+
if is_remote():
917+
s.__source_transformer__ = self # type: ignore[attr-defined]
918+
return s
913919

914920

915921
class LinearSVCSummary(_BinaryClassificationSummary):
@@ -1578,14 +1584,16 @@ def summary(self) -> "LogisticRegressionTrainingSummary":
15781584
trained on the training set. An exception is thrown if `trainingSummary is None`.
15791585
"""
15801586
if self.hasSummary:
1587+
s: LogisticRegressionTrainingSummary
15811588
if self.numClasses <= 2:
1582-
return BinaryLogisticRegressionTrainingSummary(
1589+
s = BinaryLogisticRegressionTrainingSummary(
15831590
super(LogisticRegressionModel, self).summary
15841591
)
15851592
else:
1586-
return LogisticRegressionTrainingSummary(
1587-
super(LogisticRegressionModel, self).summary
1588-
)
1593+
s = LogisticRegressionTrainingSummary(super(LogisticRegressionModel, self).summary)
1594+
if is_remote():
1595+
s.__source_transformer__ = self # type: ignore[attr-defined]
1596+
return s
15891597
else:
15901598
raise RuntimeError(
15911599
"No training summary available for this %s" % self.__class__.__name__
@@ -1605,10 +1613,14 @@ def evaluate(self, dataset: DataFrame) -> "LogisticRegressionSummary":
16051613
if not isinstance(dataset, DataFrame):
16061614
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
16071615
java_blr_summary = self._call_java("evaluate", dataset)
1616+
s: LogisticRegressionSummary
16081617
if self.numClasses <= 2:
1609-
return BinaryLogisticRegressionSummary(java_blr_summary)
1618+
s = BinaryLogisticRegressionSummary(java_blr_summary)
16101619
else:
1611-
return LogisticRegressionSummary(java_blr_summary)
1620+
s = LogisticRegressionSummary(java_blr_summary)
1621+
if is_remote():
1622+
s.__source_transformer__ = self # type: ignore[attr-defined]
1623+
return s
16121624

16131625

16141626
class LogisticRegressionSummary(_ClassificationSummary):
@@ -2304,22 +2316,24 @@ def summary(self) -> "RandomForestClassificationTrainingSummary":
23042316
trained on the training set. An exception is thrown if `trainingSummary is None`.
23052317
"""
23062318
if self.hasSummary:
2319+
s: RandomForestClassificationTrainingSummary
23072320
if self.numClasses <= 2:
2308-
return BinaryRandomForestClassificationTrainingSummary(
2321+
s = BinaryRandomForestClassificationTrainingSummary(
23092322
super(RandomForestClassificationModel, self).summary
23102323
)
23112324
else:
2312-
return RandomForestClassificationTrainingSummary(
2325+
s = RandomForestClassificationTrainingSummary(
23132326
super(RandomForestClassificationModel, self).summary
23142327
)
2328+
if is_remote():
2329+
s.__source_transformer__ = self # type: ignore[attr-defined]
2330+
return s
23152331
else:
23162332
raise RuntimeError(
23172333
"No training summary available for this %s" % self.__class__.__name__
23182334
)
23192335

2320-
def evaluate(
2321-
self, dataset: DataFrame
2322-
) -> Union["BinaryRandomForestClassificationSummary", "RandomForestClassificationSummary"]:
2336+
def evaluate(self, dataset: DataFrame) -> "RandomForestClassificationSummary":
23232337
"""
23242338
Evaluates the model on a test dataset.
23252339
@@ -2333,10 +2347,14 @@ def evaluate(
23332347
if not isinstance(dataset, DataFrame):
23342348
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
23352349
java_rf_summary = self._call_java("evaluate", dataset)
2350+
s: RandomForestClassificationSummary
23362351
if self.numClasses <= 2:
2337-
return BinaryRandomForestClassificationSummary(java_rf_summary)
2352+
s = BinaryRandomForestClassificationSummary(java_rf_summary)
23382353
else:
2339-
return RandomForestClassificationSummary(java_rf_summary)
2354+
s = RandomForestClassificationSummary(java_rf_summary)
2355+
if is_remote():
2356+
s.__source_transformer__ = self # type: ignore[attr-defined]
2357+
return s
23402358

23412359

23422360
class RandomForestClassificationSummary(_ClassificationSummary):
@@ -2363,7 +2381,10 @@ class RandomForestClassificationTrainingSummary(
23632381

23642382

23652383
@inherit_doc
2366-
class BinaryRandomForestClassificationSummary(_BinaryClassificationSummary):
2384+
class BinaryRandomForestClassificationSummary(
2385+
_BinaryClassificationSummary,
2386+
RandomForestClassificationSummary,
2387+
):
23672388
"""
23682389
BinaryRandomForestClassification results for a given model.
23692390
@@ -3341,9 +3362,12 @@ def summary( # type: ignore[override]
33413362
trained on the training set. An exception is thrown if `trainingSummary is None`.
33423363
"""
33433364
if self.hasSummary:
3344-
return MultilayerPerceptronClassificationTrainingSummary(
3365+
s = MultilayerPerceptronClassificationTrainingSummary(
33453366
super(MultilayerPerceptronClassificationModel, self).summary
33463367
)
3368+
if is_remote():
3369+
s.__source_transformer__ = self # type: ignore[attr-defined]
3370+
return s
33473371
else:
33483372
raise RuntimeError(
33493373
"No training summary available for this %s" % self.__class__.__name__
@@ -3363,7 +3387,10 @@ def evaluate(self, dataset: DataFrame) -> "MultilayerPerceptronClassificationSum
33633387
if not isinstance(dataset, DataFrame):
33643388
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
33653389
java_mlp_summary = self._call_java("evaluate", dataset)
3366-
return MultilayerPerceptronClassificationSummary(java_mlp_summary)
3390+
s = MultilayerPerceptronClassificationSummary(java_mlp_summary)
3391+
if is_remote():
3392+
s.__source_transformer__ = self # type: ignore[attr-defined]
3393+
return s
33673394

33683395

33693396
class MultilayerPerceptronClassificationSummary(_ClassificationSummary):
@@ -4290,7 +4317,10 @@ def summary(self) -> "FMClassificationTrainingSummary":
42904317
trained on the training set. An exception is thrown if `trainingSummary is None`.
42914318
"""
42924319
if self.hasSummary:
4293-
return FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
4320+
s = FMClassificationTrainingSummary(super(FMClassificationModel, self).summary)
4321+
if is_remote():
4322+
s.__source_transformer__ = self # type: ignore[attr-defined]
4323+
return s
42944324
else:
42954325
raise RuntimeError(
42964326
"No training summary available for this %s" % self.__class__.__name__
@@ -4310,7 +4340,10 @@ def evaluate(self, dataset: DataFrame) -> "FMClassificationSummary":
43104340
if not isinstance(dataset, DataFrame):
43114341
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
43124342
java_fm_summary = self._call_java("evaluate", dataset)
4313-
return FMClassificationSummary(java_fm_summary)
4343+
s = FMClassificationSummary(java_fm_summary)
4344+
if is_remote():
4345+
s.__source_transformer__ = self # type: ignore[attr-defined]
4346+
return s
43144347

43154348

43164349
class FMClassificationSummary(_BinaryClassificationSummary):

python/pyspark/ml/clustering.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,10 @@ def summary(self) -> "GaussianMixtureSummary":
263263
training set. An exception is thrown if no summary exists.
264264
"""
265265
if self.hasSummary:
266-
return GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
266+
s = GaussianMixtureSummary(super(GaussianMixtureModel, self).summary)
267+
if is_remote():
268+
s.__source_transformer__ = self # type: ignore[attr-defined]
269+
return s
267270
else:
268271
raise RuntimeError(
269272
"No training summary available for this %s" % self.__class__.__name__
@@ -710,7 +713,10 @@ def summary(self) -> KMeansSummary:
710713
training set. An exception is thrown if no summary exists.
711714
"""
712715
if self.hasSummary:
713-
return KMeansSummary(super(KMeansModel, self).summary)
716+
s = KMeansSummary(super(KMeansModel, self).summary)
717+
if is_remote():
718+
s.__source_transformer__ = self # type: ignore[attr-defined]
719+
return s
714720
else:
715721
raise RuntimeError(
716722
"No training summary available for this %s" % self.__class__.__name__
@@ -1057,7 +1063,10 @@ def summary(self) -> "BisectingKMeansSummary":
10571063
training set. An exception is thrown if no summary exists.
10581064
"""
10591065
if self.hasSummary:
1060-
return BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
1066+
s = BisectingKMeansSummary(super(BisectingKMeansModel, self).summary)
1067+
if is_remote():
1068+
s.__source_transformer__ = self # type: ignore[attr-defined]
1069+
return s
10611070
else:
10621071
raise RuntimeError(
10631072
"No training summary available for this %s" % self.__class__.__name__

python/pyspark/ml/regression.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,10 @@ def summary(self) -> "LinearRegressionTrainingSummary":
487487
`trainingSummary is None`.
488488
"""
489489
if self.hasSummary:
490-
return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
490+
s = LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
491+
if is_remote():
492+
s.__source_transformer__ = self # type: ignore[attr-defined]
493+
return s
491494
else:
492495
raise RuntimeError(
493496
"No training summary available for this %s" % self.__class__.__name__
@@ -508,7 +511,10 @@ def evaluate(self, dataset: DataFrame) -> "LinearRegressionSummary":
508511
if not isinstance(dataset, DataFrame):
509512
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
510513
java_lr_summary = self._call_java("evaluate", dataset)
511-
return LinearRegressionSummary(java_lr_summary)
514+
s = LinearRegressionSummary(java_lr_summary)
515+
if is_remote():
516+
s.__source_transformer__ = self # type: ignore[attr-defined]
517+
return s
512518

513519

514520
class LinearRegressionSummary(JavaWrapper):
@@ -2766,9 +2772,12 @@ def summary(self) -> "GeneralizedLinearRegressionTrainingSummary":
27662772
`trainingSummary is None`.
27672773
"""
27682774
if self.hasSummary:
2769-
return GeneralizedLinearRegressionTrainingSummary(
2775+
s = GeneralizedLinearRegressionTrainingSummary(
27702776
super(GeneralizedLinearRegressionModel, self).summary
27712777
)
2778+
if is_remote():
2779+
s.__source_transformer__ = self # type: ignore[attr-defined]
2780+
return s
27722781
else:
27732782
raise RuntimeError(
27742783
"No training summary available for this %s" % self.__class__.__name__
@@ -2789,7 +2798,10 @@ def evaluate(self, dataset: DataFrame) -> "GeneralizedLinearRegressionSummary":
27892798
if not isinstance(dataset, DataFrame):
27902799
raise TypeError("dataset must be a DataFrame but got %s." % type(dataset))
27912800
java_glr_summary = self._call_java("evaluate", dataset)
2792-
return GeneralizedLinearRegressionSummary(java_glr_summary)
2801+
s = GeneralizedLinearRegressionSummary(java_glr_summary)
2802+
if is_remote():
2803+
s.__source_transformer__ = self # type: ignore[attr-defined]
2804+
return s
27932805

27942806

27952807
class GeneralizedLinearRegressionSummary(JavaWrapper):

python/pyspark/ml/tests/test_pipeline.py

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
)
3030
from pyspark.ml.linalg import Vectors
3131
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
32-
from pyspark.ml.clustering import KMeans, KMeansModel
32+
from pyspark.ml.clustering import KMeans, KMeansModel, GaussianMixture
3333
from pyspark.testing.mlutils import MockDataset, MockEstimator, MockTransformer
3434
from pyspark.testing.sqlutils import ReusedSQLTestCase
3535

@@ -176,7 +176,7 @@ def test_clustering_pipeline(self):
176176

177177
def test_model_gc(self):
178178
spark = self.spark
179-
df = spark.createDataFrame(
179+
df1 = spark.createDataFrame(
180180
[
181181
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
182182
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
@@ -189,8 +189,107 @@ def fit_transform(df):
189189
model = lr.fit(df)
190190
return model.transform(df)
191191

192-
output = fit_transform(df)
193-
self.assertEqual(output.count(), 3)
192+
output1 = fit_transform(df1)
193+
self.assertEqual(output1.count(), 3)
194+
195+
df2 = spark.range(10)
196+
197+
def fit_transform_and_union(df1, df2):
198+
output1 = fit_transform(df1)
199+
return output1.unionByName(df2, True)
200+
201+
output2 = fit_transform_and_union(df1, df2)
202+
self.assertEqual(output2.count(), 13)
203+
204+
def test_model_training_summary_gc(self):
205+
spark = self.spark
206+
df1 = spark.createDataFrame(
207+
[
208+
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
209+
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
210+
Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
211+
]
212+
)
213+
214+
def fit_predictions(df):
215+
lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
216+
model = lr.fit(df)
217+
return model.summary.predictions
218+
219+
output1 = fit_predictions(df1)
220+
self.assertEqual(output1.count(), 3)
221+
222+
df2 = spark.range(10)
223+
224+
def fit_predictions_and_union(df1, df2):
225+
output1 = fit_predictions(df1)
226+
return output1.unionByName(df2, True)
227+
228+
output2 = fit_predictions_and_union(df1, df2)
229+
self.assertEqual(output2.count(), 13)
230+
231+
def test_model_testing_summary_gc(self):
232+
spark = self.spark
233+
df1 = spark.createDataFrame(
234+
[
235+
Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
236+
Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
237+
Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
238+
]
239+
)
240+
241+
def fit_predictions(df):
242+
lr = LogisticRegression(maxIter=1, regParam=0.01, weightCol="weight")
243+
model = lr.fit(df)
244+
return model.evaluate(df).predictions
245+
246+
output1 = fit_predictions(df1)
247+
self.assertEqual(output1.count(), 3)
248+
249+
df2 = spark.range(10)
250+
251+
def fit_predictions_and_union(df1, df2):
252+
output1 = fit_predictions(df1)
253+
return output1.unionByName(df2, True)
254+
255+
output2 = fit_predictions_and_union(df1, df2)
256+
self.assertEqual(output2.count(), 13)
257+
258+
def test_model_attr_df_gc(self):
259+
spark = self.spark
260+
df1 = (
261+
spark.createDataFrame(
262+
[
263+
(1, 1.0, Vectors.dense([-0.1, -0.05])),
264+
(2, 2.0, Vectors.dense([-0.01, -0.1])),
265+
(3, 3.0, Vectors.dense([0.9, 0.8])),
266+
(4, 1.0, Vectors.dense([0.75, 0.935])),
267+
(5, 1.0, Vectors.dense([-0.83, -0.68])),
268+
(6, 1.0, Vectors.dense([-0.91, -0.76])),
269+
],
270+
["index", "weight", "features"],
271+
)
272+
.coalesce(1)
273+
.sortWithinPartitions("index")
274+
.select("weight", "features")
275+
)
276+
277+
def fit_attr_df(df):
278+
gmm = GaussianMixture(k=2, maxIter=2, weightCol="weight", seed=1)
279+
model = gmm.fit(df)
280+
return model.gaussiansDF
281+
282+
output1 = fit_attr_df(df1)
283+
self.assertEqual(output1.count(), 2)
284+
285+
df2 = spark.range(10)
286+
287+
def fit_attr_df_and_union(df1, df2):
288+
output1 = fit_attr_df(df1)
289+
return output1.unionByName(df2, True)
290+
291+
output2 = fit_attr_df_and_union(df1, df2)
292+
self.assertEqual(output2.count(), 12)
194293

195294

196295
class PipelineTests(PipelineTestsMixin, ReusedSQLTestCase):

0 commit comments

Comments (0)