
Commit c0f6db8

[SPARK-48883][ML][R] Replace RDD read / write API invocation with Dataframe read / write API
### What changes were proposed in this pull request?

This PR is a retry of #47328, which replaces the RDD read / write API with the Dataset API for writing SparkR metadata. In addition, this PR removes `repartition(1)`: we don't actually need it when the input is a single row, because that creates only a single partition anyway:

https://github.com/apache/spark/blob/e5e751b98f9ef5b8640079c07a9a342ef471d75d/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala#L49-L57

### Why are the changes needed?

In order to leverage the Catalyst optimizer and the SQL engine. For example, we now use UTF-8 encoding for strings instead of plain JDK ser/de. We have made similar changes in the past, e.g., #29063, #15813, #17255 and SPARK-19918.

We also remove `repartition(1)` to avoid an unnecessary shuffle.

With `repartition(1)`:

```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Exchange SinglePartition, REPARTITION_BY_NUM, [plan_id=6]
   +- LocalTableScan [_1#0]
```

Without `repartition(1)`:

```
== Physical Plan ==
LocalTableScan [_1#2]
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI in this PR should verify the change.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47341 from HyukjinKwon/SPARK-48883-followup.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
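For context, the whole change boils down to the same two-line swap in every wrapper. Below is a minimal sketch of the round trip — not the committed code; it assumes a local SparkSession, a hypothetical wrapper JSON string, and a scratch output path:

```scala
import org.apache.spark.sql.SparkSession

object RMetadataRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rMetadata-demo").getOrCreate()

    // Hypothetical stand-ins for the wrapper metadata and output path.
    val rMetadataJson = """{"class":"org.example.Wrapper","features":["x","y"]}"""
    val rMetadataPath = "/tmp/rMetadata-demo"

    // Write side: a single-row local relation plans as a bare LocalTableScan
    // with one partition, so no repartition(1) is needed before write.text.
    val df = spark.createDataFrame(Seq(Tuple1(rMetadataJson)))
    df.explain() // == Physical Plan == LocalTableScan [_1#...]
    df.write.text(rMetadataPath)

    // Read side: the DataFrame text source replaces sc.textFile(path, 1).first().
    val restored = spark.read.text(rMetadataPath).first().getString(0)
    assert(restored == rMetadataJson)

    spark.stop()
  }
}
```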
Parent: 3755d51

23 files changed · +110 −44 lines changed

mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -129,7 +129,9 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
       val rMetadata = ("class" -> instance.getClass.getName) ~
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -142,7 +144,8 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
 
```

mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -94,7 +94,9 @@ private[r] object ALSWrapper extends MLReadable[ALSWrapper] {
       val rMetadata = ("class" -> instance.getClass.getName) ~
         ("ratingCol" -> instance.ratingCol)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.alsModel.save(modelPath)
     }
@@ -107,7 +109,8 @@ private[r] object ALSWrapper extends MLReadable[ALSWrapper] {
       val rMetadataPath = new Path(path, "rMetadata").toString
       val modelPath = new Path(path, "model").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val ratingCol = (rMetadata \ "ratingCol").extract[String]
       val alsModel = ALSModel.load(modelPath)
```

mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -120,7 +120,9 @@ private[r] object BisectingKMeansWrapper extends MLReadable[BisectingKMeansWrapp
         ("size" -> instance.size.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -133,7 +135,8 @@ private[r] object BisectingKMeansWrapper extends MLReadable[BisectingKMeansWrapp
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
       val size = (rMetadata \ "size").extract[Array[Long]]
```

mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassifierWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -131,7 +131,9 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -144,7 +146,8 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val formula = (rMetadata \ "formula").extract[String]
       val features = (rMetadata \ "features").extract[Array[String]]
```

mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressorWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -114,7 +114,9 @@ private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRe
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -127,7 +129,8 @@ private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRe
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val formula = (rMetadata \ "formula").extract[String]
       val features = (rMetadata \ "features").extract[Array[String]]
```

mllib/src/main/scala/org/apache/spark/ml/r/FMClassifierWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -151,7 +151,9 @@ private[r] object FMClassifierWrapper
         ("features" -> instance.features.toImmutableArraySeq) ~
         ("labels" -> instance.labels.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -164,7 +166,8 @@ private[r] object FMClassifierWrapper
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
       val labels = (rMetadata \ "labels").extract[Array[String]]
```

mllib/src/main/scala/org/apache/spark/ml/r/FMRegressorWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -132,7 +132,9 @@ private[r] object FMRegressorWrapper
       val rMetadata = ("class" -> instance.getClass.getName) ~
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -145,7 +147,8 @@ private[r] object FMRegressorWrapper
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
 
```

mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala

Lines changed: 3 additions & 2 deletions
```diff
@@ -77,8 +77,9 @@ private[r] object FPGrowthWrapper extends MLReadable[FPGrowthWrapper] {
       val rMetadataJson: String = compact(render(
         "class" -> instance.getClass.getName
       ))
-
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.fpGrowthModel.save(modelPath)
     }
```

mllib/src/main/scala/org/apache/spark/ml/r/GBTClassifierWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -138,7 +138,9 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      sparkSession.createDataFrame(
+        Seq(Tuple1(rMetadataJson))
+      ).repartition(1).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -151,7 +153,8 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val formula = (rMetadata \ "formula").extract[String]
       val features = (rMetadata \ "features").extract[Array[String]]
```
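Unlike the other wrappers in this commit, the hunk above keeps an explicit `repartition(1)`. For reference, here is a minimal sketch (assuming a running SparkSession named `spark`; the plan output is the one quoted in the commit message) of how that call adds a shuffle for a single-row local relation:

```scala
val df = spark.createDataFrame(Seq(Tuple1("{}")))

// Without repartition(1): a bare LocalTableScan, already a single partition.
df.explain()
// == Physical Plan ==
// LocalTableScan [_1#2]

// With repartition(1): an extra Exchange (shuffle) node is introduced.
df.repartition(1).explain()
// == Physical Plan ==
// AdaptiveSparkPlan isFinalPlan=false
// +- Exchange SinglePartition, REPARTITION_BY_NUM, [plan_id=6]
//    +- LocalTableScan [_1#0]
```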

mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressorWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -122,7 +122,9 @@ private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] {
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -135,7 +137,8 @@ private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] {
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val formula = (rMetadata \ "formula").extract[String]
       val features = (rMetadata \ "features").extract[Array[String]]
```

mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -113,7 +113,9 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
         ("logLikelihood" -> instance.logLikelihood)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -126,7 +128,8 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val dim = (rMetadata \ "dim").extract[Int]
       val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
```

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -170,7 +170,9 @@ private[r] object GeneralizedLinearRegressionWrapper
         ("rAic" -> instance.rAic) ~
         ("rNumIterations" -> instance.rNumIterations)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -184,7 +186,8 @@ private[r] object GeneralizedLinearRegressionWrapper
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val rFeatures = (rMetadata \ "rFeatures").extract[Array[String]]
       val rCoefficients = (rMetadata \ "rCoefficients").extract[Array[Double]]
```

mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -99,7 +99,9 @@ private[r] object IsotonicRegressionWrapper
       val rMetadata = ("class" -> instance.getClass.getName) ~
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -112,7 +114,8 @@ private[r] object IsotonicRegressionWrapper
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
 
```

mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -123,7 +123,9 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
         ("size" -> instance.size.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
 
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
       instance.pipeline.save(pipelinePath)
     }
   }
@@ -136,7 +138,8 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
       val pipelinePath = new Path(path, "pipeline").toString
       val pipeline = PipelineModel.load(pipelinePath)
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
       val size = (rMetadata \ "size").extract[Array[Long]]
```

mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -198,7 +198,9 @@ private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
         ("logPerplexity" -> instance.logPerplexity) ~
         ("vocabulary" -> instance.vocabulary.toList)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -211,7 +213,8 @@ private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
       val logPerplexity = (rMetadata \ "logPerplexity").extract[Double]
```

mllib/src/main/scala/org/apache/spark/ml/r/LinearRegressionWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -127,7 +127,9 @@ private[r] object LinearRegressionWrapper
       val rMetadata = ("class" -> instance.getClass.getName) ~
         ("features" -> instance.features.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -140,7 +142,8 @@ private[r] object LinearRegressionWrapper
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
 
```

mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -137,7 +137,9 @@ private[r] object LinearSVCWrapper
         ("features" -> instance.features.toImmutableArraySeq) ~
         ("labels" -> instance.labels.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -150,7 +152,8 @@ private[r] object LinearSVCWrapper
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
       val labels = (rMetadata \ "labels").extract[Array[String]]
```

mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala

Lines changed: 5 additions & 2 deletions
```diff
@@ -192,7 +192,9 @@ private[r] object LogisticRegressionWrapper
         ("features" -> instance.features.toImmutableArraySeq) ~
         ("labels" -> instance.labels.toImmutableArraySeq)
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
@@ -205,7 +207,8 @@ private[r] object LogisticRegressionWrapper
       val rMetadataPath = new Path(path, "rMetadata").toString
       val pipelinePath = new Path(path, "pipeline").toString
 
-      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadataStr = sparkSession.read.text(rMetadataPath)
+        .first().getString(0)
       val rMetadata = parse(rMetadataStr)
       val features = (rMetadata \ "features").extract[Array[String]]
       val labels = (rMetadata \ "labels").extract[Array[String]]
```

mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -142,7 +142,9 @@ private[r] object MultilayerPerceptronClassifierWrapper
 
       val rMetadata = "class" -> instance.getClass.getName
       val rMetadataJson: String = compact(render(rMetadata))
-      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      // Note that we should write single file. If there are more than one row
+      // it produces more partitions.
+      sparkSession.createDataFrame(Seq(Tuple1(rMetadataJson))).write.text(rMetadataPath)
 
       instance.pipeline.save(pipelinePath)
     }
```
