Commit f487515

[SPARK-51567][ML][CONNECT] Fix DistributedLDAModel.vocabSize
### What changes were proposed in this pull request?
Fix `DistributedLDAModel.vocabSize`

### Why are the changes needed?
```
pyspark.errors.exceptions.connect.SparkException: [CONNECT_ML.ATTRIBUTE_NOT_ALLOWED] Generic Spark Connect ML error. vocabSize in org.apache.spark.ml.clustering.DistributedLDAModel is not allowed to be accessed. SQLSTATE: XX000

JVM stacktrace:
org.apache.spark.sql.connect.ml.MLAttributeNotAllowedException
	at org.apache.spark.sql.connect.ml.MLUtils$.validate(MLUtils.scala:686)
	at org.apache.spark.sql.connect.ml.MLUtils$.invokeMethodAllowed(MLUtils.scala:691)
	at org.apache.spark.sql.connect.ml.AttributeHelper.$anonfun$getAttribute$1(MLHandler.scala:56)
```

### Does this PR introduce _any_ user-facing change?
yes, new api supported

### How was this patch tested?
added test

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#50330 from zhengruifeng/ml_connect_lda_vocabSize.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 62c0669 commit f487515

2 files changed, +3 -2 lines changed

python/pyspark/ml/tests/test_clustering.py

Lines changed: 1 addition & 0 deletions
@@ -404,6 +404,7 @@ def test_distributed_lda(self):
         self.assertNotIsInstance(model, LocalLDAModel)
         self.assertIsInstance(model, DistributedLDAModel)
         self.assertTrue(model.isDistributed())
+        self.assertEqual(model.vocabSize(), 2)
 
         dc = model.estimatedDocConcentration()
         self.assertTrue(np.allclose(dc.toArray(), [26.0, 26.0], atol=1e-4), dc)

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 2 additions & 2 deletions
@@ -621,8 +621,8 @@ private[ml] object MLUtils {
         "isDistributed",
         "logLikelihood",
         "logPerplexity",
-        "describeTopics")),
-    (classOf[LocalLDAModel], Set("vocabSize")),
+        "describeTopics",
+        "vocabSize")),
     (
       classOf[DistributedLDAModel],
       Set("trainingLogLikelihood", "logPrior", "getCheckpointFiles", "toLocal")),
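
The effect of this hunk can be illustrated with a small Python sketch of a class-keyed allowlist. It is not the actual `MLUtils` code. Two things are assumptions here: that the first set in the hunk belongs to a base class shared by both LDA models (the hunk does not show which class owns it), and that a method is allowed when the object is an instance of any listed class. The class definitions and the `is_allowed` helper are hypothetical stand-ins, and only the method names visible in the diff are included.

```python
from typing import List, Set, Tuple, Type


# Hypothetical stand-ins for the Scala model classes (not the real Spark types).
class LDAModel:
    pass


class LocalLDAModel(LDAModel):
    pass


class DistributedLDAModel(LDAModel):
    pass


Allowlist = List[Tuple[Type, Set[str]]]

# Only the names visible in the diff hunk; the real sets may hold more entries.
SHARED = {"isDistributed", "logLikelihood", "logPerplexity", "describeTopics"}
DISTRIBUTED_ONLY = {"trainingLogLikelihood", "logPrior", "getCheckpointFiles", "toLocal"}

# Before the patch: vocabSize is listed only for LocalLDAModel.
BEFORE: Allowlist = [
    (LDAModel, set(SHARED)),  # assumed owner of the first set
    (LocalLDAModel, {"vocabSize"}),
    (DistributedLDAModel, set(DISTRIBUTED_ONLY)),
]

# After the patch: vocabSize moves into the shared set.
AFTER: Allowlist = [
    (LDAModel, SHARED | {"vocabSize"}),
    (DistributedLDAModel, set(DISTRIBUTED_ONLY)),
]


def is_allowed(allowlist: Allowlist, obj: object, name: str) -> bool:
    """Return True if any entry matching obj's type lists the method name."""
    return any(isinstance(obj, cls) and name in names for cls, names in allowlist)


if __name__ == "__main__":
    model = DistributedLDAModel()
    print(is_allowed(BEFORE, model, "vocabSize"))
    print(is_allowed(AFTER, model, "vocabSize"))
```

Under these assumptions, moving `vocabSize` into the shared set makes it reachable from both model types while leaving the distributed-only methods unchanged.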
