Skip to content

Commit 860f449

Browse files
Jing Chen Hesrowen
Jing Chen He
authored andcommitted
[SPARK-26315][PYSPARK] auto cast threshold from Integer to Float in approxSimilarityJoin of BucketedRandomProjectionLSHModel
## What changes were proposed in this pull request? If the input parameter 'threshold' to the function approxSimilarityJoin is not a float, we would get an exception. The fix is to convert the 'threshold' into a float before calling the java implementation method. ## How was this patch tested? Added a new test case. Without this fix, the test will throw an exception as reported in the JIRA. With the fix, the test passes. Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #23313 from jerryjch/SPARK-26315. Authored-by: Jing Chen He <jinghe@us.ibm.com> Signed-off-by: Sean Owen <sean.owen@databricks.com>
1 parent 9ccae0c commit 860f449

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

python/pyspark/ml/feature.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def approxSimilarityJoin(self, datasetA, datasetB, threshold, distCol="distCol")
192192
"datasetA" and "datasetB", and a column "distCol" is added to show the distance
193193
between each pair.
194194
"""
195+
threshold = TypeConverters.toFloat(threshold)
195196
return self._call_java("approxSimilarityJoin", datasetA, datasetB, threshold, distCol)
196197

197198

@@ -239,6 +240,16 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp
239240
| 3| 6| 2.23606797749979|
240241
+---+---+-----------------+
241242
...
243+
>>> model.approxSimilarityJoin(df, df2, 3, distCol="EuclideanDistance").select(
244+
... col("datasetA.id").alias("idA"),
245+
... col("datasetB.id").alias("idB"),
246+
... col("EuclideanDistance")).show()
247+
+---+---+-----------------+
248+
|idA|idB|EuclideanDistance|
249+
+---+---+-----------------+
250+
| 3| 6| 2.23606797749979|
251+
+---+---+-----------------+
252+
...
242253
>>> brpPath = temp_path + "/brp"
243254
>>> brp.save(brpPath)
244255
>>> brp2 = BucketedRandomProjectionLSH.load(brpPath)

0 commit comments

Comments
 (0)