Commit 869bfc9

Jing Chen He authored and srowen committed
[SPARK-26315][PYSPARK] auto cast threshold from Integer to Float in approxSimilarityJoin of BucketedRandomProjectionLSHModel
## What changes were proposed in this pull request?

If the input parameter 'threshold' to the function approxSimilarityJoin is not a float, we would get an exception. The fix is to convert the 'threshold' into a float before calling the java implementation method.

## How was this patch tested?

Added a new test case. Without this fix, the test will throw an exception as reported in the JIRA. With the fix, the test passes.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Closes #23313 from jerryjch/SPARK-26315.

Authored-by: Jing Chen He <jinghe@us.ibm.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>
(cherry picked from commit 860f449)
Signed-off-by: Sean Owen <sean.owen@databricks.com>

1 parent 6019d9a commit 869bfc9

1 file changed: +11 -0 lines changed

python/pyspark/ml/feature.py

Lines changed: 11 additions & 0 deletions
@@ -193,6 +193,7 @@ def approxSimilarityJoin(self, datasetA, datasetB, threshold, distCol="distCol")
         "datasetA" and "datasetB", and a column "distCol" is added to show the distance
         between each pair.
         """
+        threshold = TypeConverters.toFloat(threshold)
         return self._call_java("approxSimilarityJoin", datasetA, datasetB, threshold, distCol)
 
 
@@ -240,6 +241,16 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp
     |  3|  6| 2.23606797749979|
     +---+---+-----------------+
     ...
+    >>> model.approxSimilarityJoin(df, df2, 3, distCol="EuclideanDistance").select(
+    ...     col("datasetA.id").alias("idA"),
+    ...     col("datasetB.id").alias("idB"),
+    ...     col("EuclideanDistance")).show()
+    +---+---+-----------------+
+    |idA|idB|EuclideanDistance|
+    +---+---+-----------------+
+    |  3|  6| 2.23606797749979|
+    +---+---+-----------------+
+    ...
     >>> brpPath = temp_path + "/brp"
     >>> brp.save(brpPath)
     >>> brp2 = BucketedRandomProjectionLSH.load(brpPath)
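
The effect of the one-line change can be sketched without a Spark session. The snippet below is a rough illustration, not the PySpark implementation: `to_float` is a simplified stand-in for the `TypeConverters.toFloat` call added in the patch (assumed here to accept only plain `int` and `float`), and `call_java_stub` is a hypothetical placeholder for a JVM method that accepts only a double, which is assumed to be why an integer threshold failed before the fix.

```python
def to_float(value):
    """Simplified stand-in for the TypeConverters.toFloat call in the patch.

    Assumption: only plain int and float are accepted in this sketch.
    """
    if type(value) in (int, float):
        return float(value)
    raise TypeError("Could not convert %s to float" % (value,))


def call_java_stub(threshold):
    """Hypothetical placeholder for a JVM method that only accepts a double."""
    if type(threshold) is not float:
        raise TypeError(
            "no method matching argument type %s" % type(threshold).__name__
        )
    return threshold


def join_before_fix(threshold):
    # Before the patch: the threshold is passed through unchanged.
    return call_java_stub(threshold)


def join_after_fix(threshold):
    # After the patch: the threshold is converted to float first.
    threshold = to_float(threshold)
    return call_java_stub(threshold)


if __name__ == "__main__":
    print(join_after_fix(3))      # integer threshold is accepted after conversion
    print(join_after_fix(3.0))    # float threshold behaves as before
    try:
        join_before_fix(3)
    except TypeError as exc:
        print("before the fix:", exc)
```

Converting inside the Python wrapper keeps the public signature unchanged, so existing callers that already pass a float are unaffected.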
