From 7910dc3175497202ea4b5b4d02fc9262c1f63163 Mon Sep 17 00:00:00 2001 From: Nilesh Gajwani Date: Thu, 8 Jun 2023 14:17:43 -0700 Subject: [PATCH] [SEDONA-239] Implement ST_NumPoints (#853) --- .../org/apache/sedona/common/Functions.java | 8 +++++++ .../apache/sedona/common/FunctionsTest.java | 16 ++++++++++++++ docs/api/flink/Function.md | 22 +++++++++++++++++++ docs/api/sql/Function.md | 20 +++++++++++++++++ .../java/org/apache/sedona/flink/Catalog.java | 3 ++- .../sedona/flink/expressions/Functions.java | 8 +++++++ .../org/apache/sedona/flink/FunctionTest.java | 8 +++++++ python/sedona/sql/st_functions.py | 10 +++++++++ python/tests/sql/test_dataframe_api.py | 1 + python/tests/sql/test_function.py | 5 +++++ .../org/apache/sedona/sql/UDF/Catalog.scala | 1 + .../sedona_sql/expressions/Functions.scala | 7 ++++++ .../sedona_sql/expressions/st_functions.scala | 4 ++++ .../sedona/sql/dataFrameAPITestScala.scala | 8 +++++++ .../apache/sedona/sql/functionTestScala.scala | 12 ++++++++++ 15 files changed, 132 insertions(+), 1 deletion(-) diff --git a/common/src/main/java/org/apache/sedona/common/Functions.java b/common/src/main/java/org/apache/sedona/common/Functions.java index ad7af11986..7fa3d802ba 100644 --- a/common/src/main/java/org/apache/sedona/common/Functions.java +++ b/common/src/main/java/org/apache/sedona/common/Functions.java @@ -845,6 +845,14 @@ private static Coordinate[] extractCoordinates(Geometry geometry) { return coordinates; } + public static int numPoints(Geometry geometry) throws Exception { + String geometryType = geometry.getGeometryType(); + if (!(Geometry.TYPENAME_LINESTRING.equalsIgnoreCase(geometryType))) { + throw new IllegalArgumentException("Unsupported geometry type: " + geometryType + ", only LineString geometry is supported."); + } + return geometry.getNumPoints(); + } + public static Geometry geometricMedian(Geometry geometry, double tolerance, int maxIter, boolean failIfNotConverged) throws Exception { String geometryType = 
geometry.getGeometryType(); if(!(Geometry.TYPENAME_POINT.equals(geometryType) || Geometry.TYPENAME_MULTIPOINT.equals(geometryType))) { diff --git a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java index 029c888a6f..93d20b0702 100644 --- a/common/src/test/java/org/apache/sedona/common/FunctionsTest.java +++ b/common/src/test/java/org/apache/sedona/common/FunctionsTest.java @@ -565,4 +565,20 @@ public void spheroidLength() { GeometryCollection geometryCollection = GEOMETRY_FACTORY.createGeometryCollection(new Geometry[] {point, line, multiLineString}); assertEquals(3.0056262514183864E7, Spheroid.length(geometryCollection), 0.1); } + + @Test + public void numPoints() throws Exception{ + LineString line = GEOMETRY_FACTORY.createLineString(coordArray(0, 1, 1, 0, 2, 0)); + int expected = 3; + int actual = Functions.numPoints(line); + assertEquals(expected, actual); + } + + @Test + public void numPointsUnsupported() throws Exception { + Polygon polygon = GEOMETRY_FACTORY.createPolygon(coordArray(0, 0, 0, 90, 0, 0)); + String expected = "Unsupported geometry type: " + "Polygon" + ", only LineString geometry is supported."; + Exception e = assertThrows(IllegalArgumentException.class, () -> Functions.numPoints(polygon)); + assertEquals(expected, e.getMessage()); + } } diff --git a/docs/api/flink/Function.md b/docs/api/flink/Function.md index e2f1d14241..0b79c9b9c1 100644 --- a/docs/api/flink/Function.md +++ b/docs/api/flink/Function.md @@ -695,6 +695,28 @@ SELECT ST_NumInteriorRings(ST_GeomFromText('POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), Output: `1` +## ST_NumPoints + +Introduction: Returns number of points in a LineString. + +!!!note + If any other geometry is provided as an argument, an IllegalArgumentException is thrown. 
+ Example: + `SELECT ST_NumPoints(ST_GeomFromWKT('MULTIPOINT ((0 0), (1 1), (0 1), (2 2))'))` + + Output: `IllegalArgumentException: Unsupported geometry type: MultiPoint, only LineString geometry is supported.` + +Format: `ST_NumPoints(geom: geometry)` + +Since: `v1.4.1` + +Example: +```sql +SELECT ST_NumPoints(ST_GeomFromText('LINESTRING(1 2, 1 3)')) +``` + +Output: `2` + ## ST_PointN Introduction: Return the Nth point in a single linestring or circular linestring in the geometry. Negative values are counted backwards from the end of the LineString, so that -1 is the last point. Returns NULL if there is no linestring in the geometry. diff --git a/docs/api/sql/Function.md b/docs/api/sql/Function.md index 3eba9a036f..9935f0f069 100644 --- a/docs/api/sql/Function.md +++ b/docs/api/sql/Function.md @@ -1093,6 +1093,26 @@ SELECT ST_NumInteriorRings(ST_GeomFromText('POLYGON ((0 0, 0 5, 5 5, 5 0, 0 0), Output: `1` +## ST_NumPoints +Introduction: Returns the number of points in a LineString. + +!!!note + If any other geometry is provided as an argument, an IllegalArgumentException is thrown. + Example: + `SELECT ST_NumPoints(ST_GeomFromWKT('MULTIPOINT ((0 0), (1 1), (0 1), (2 2))'))` + + Output: `IllegalArgumentException: Unsupported geometry type: MultiPoint, only LineString geometry is supported.` +Format: `ST_NumPoints(geom: geometry)` + +Since: `v1.4.1` + +Spark SQL example: +```sql +SELECT ST_NumPoints(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)')) +``` + +Output: `3` + ## ST_PointN Introduction: Return the Nth point in a single linestring or circular linestring in the geometry. Negative values are counted backwards from the end of the LineString, so that -1 is the last point. Returns NULL if there is no linestring in the geometry. 
diff --git a/flink/src/main/java/org/apache/sedona/flink/Catalog.java b/flink/src/main/java/org/apache/sedona/flink/Catalog.java index 66e4bffa2b..884126c00e 100644 --- a/flink/src/main/java/org/apache/sedona/flink/Catalog.java +++ b/flink/src/main/java/org/apache/sedona/flink/Catalog.java @@ -94,7 +94,8 @@ public static UserDefinedFunction[] getFuncs() { new Functions.ST_LineFromMultiPoint(), new Functions.ST_Split(), new Functions.ST_S2CellIDs(), - new Functions.ST_GeometricMedian() + new Functions.ST_GeometricMedian(), + new Functions.ST_NumPoints() }; } diff --git a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java index 608a461bdc..7001345a4c 100644 --- a/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java +++ b/flink/src/main/java/org/apache/sedona/flink/expressions/Functions.java @@ -574,4 +574,12 @@ public Geometry eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.j } + public static class ST_NumPoints extends ScalarFunction { + @DataTypeHint(value = "Integer") + public int eval(@DataTypeHint(value = "RAW", bridgedTo = org.locationtech.jts.geom.Geometry.class) Object o) throws Exception { + Geometry geometry = (Geometry) o; + return org.apache.sedona.common.Functions.numPoints(geometry); + } + } + } diff --git a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java index 933e216fe1..eac04d2fde 100644 --- a/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java +++ b/flink/src/test/java/org/apache/sedona/flink/FunctionTest.java @@ -691,4 +691,12 @@ public void testGeometricMedianParamsFull() throws ParseException { 0, expected.compareTo(actual, COORDINATE_SEQUENCE_COMPARATOR)); } + @Test + public void testNumPoints() { + Integer expected = 3; + Table pointTable = tableEnv.sqlQuery("SELECT ST_NumPoints(ST_GeomFromWKT('LINESTRING(0 1, 1 0, 2 0)'))"); + 
Integer actual = (Integer) first(pointTable).getField(0); + assertEquals(expected, actual); + } + } diff --git a/python/sedona/sql/st_functions.py b/python/sedona/sql/st_functions.py index f3cea50b4b..d5c7602b1d 100644 --- a/python/sedona/sql/st_functions.py +++ b/python/sedona/sql/st_functions.py @@ -108,6 +108,7 @@ "ST_Z", "ST_ZMax", "ST_ZMin", + "ST_NumPoints" ] @@ -1231,3 +1232,12 @@ def ST_ZMin(geometry: ColumnOrName) -> Column: :rtype: Column """ return _call_st_function("ST_ZMin", geometry) + +def ST_NumPoints(geometry: ColumnOrName) -> Column: + """Return the number of points in a LineString + :param geometry: Geometry column to get number of points from. + :type geometry: ColumnOrName + :return: Number of points in a LineString as an integer column + :rtype: Column + """ + return _call_st_function("ST_NumPoints", geometry) diff --git a/python/tests/sql/test_dataframe_api.py b/python/tests/sql/test_dataframe_api.py index 6551d3bda9..1ece9f699f 100644 --- a/python/tests/sql/test_dataframe_api.py +++ b/python/tests/sql/test_dataframe_api.py @@ -135,6 +135,7 @@ (stf.ST_YMax, ("geom",), "triangle_geom", "", 1.0), (stf.ST_YMin, ("geom",), "triangle_geom", "", 0.0), (stf.ST_Z, ("b",), "two_points", "", 4.0), + (stf.ST_NumPoints, ("line",), "linestring_geom", "", 6), # predicates (stp.ST_Contains, ("geom", lambda: f.expr("ST_Point(0.5, 0.25)")), "triangle_geom", "", True), diff --git a/python/tests/sql/test_function.py b/python/tests/sql/test_function.py index 059fd30b6c..ba6657741f 100644 --- a/python/tests/sql/test_function.py +++ b/python/tests/sql/test_function.py @@ -1074,3 +1074,8 @@ def test_st_s2_cell_ids(self): # test null case cell_ids = self.spark.sql("select ST_S2CellIDs(null, 6)").take(1)[0][0] assert cell_ids is None + + def test_st_numPoints(self): + actual = self.spark.sql("SELECT ST_NumPoints(ST_GeomFromText('LINESTRING(0 1, 1 0, 2 0)'))").take(1)[0][0] + expected = 3 + assert expected == actual diff --git 
a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index d50af2ab4f..eede0a130b 100644 --- a/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/sql/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -147,6 +147,7 @@ object Catalog { function[ST_DistanceSpheroid](), function[ST_AreaSpheroid](), function[ST_LengthSpheroid](), + function[ST_NumPoints](), // Expression for rasters function[RS_NormalizedDifference](), function[RS_Mean](), diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala index 0a4c0dcb4e..2a6dfde3c6 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Functions.scala @@ -981,3 +981,10 @@ case class ST_LengthSpheroid(inputExpressions: Seq[Expression]) } } +case class ST_NumPoints(inputExpressions: Seq[Expression]) + extends InferredUnaryExpression(Functions.numPoints) with FoldableExpression { + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } +} + diff --git a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala index ad29b854ed..94e0874d0d 100644 --- a/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala +++ b/sql/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/st_functions.scala @@ -301,4 +301,8 @@ object st_functions extends DataFrameAPI { def ST_LengthSpheroid(a: Column): Column = wrapExpression[ST_LengthSpheroid](a) def ST_LengthSpheroid(a: String): Column = wrapExpression[ST_LengthSpheroid](a) + + def ST_NumPoints(geometry: 
Column): Column = wrapExpression[ST_NumPoints](geometry) + + def ST_NumPoints(geometry: String): Column = wrapExpression[ST_NumPoints](geometry) } diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala b/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala index 787a7300cd..e3eaf8ff41 100644 --- a/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala +++ b/sql/common/src/test/scala/org/apache/sedona/sql/dataFrameAPITestScala.scala @@ -949,5 +949,13 @@ class dataFrameAPITestScala extends TestBaseScala { val expectedResult = 10018754.171394622 assertEquals(expectedResult, actualResult, 0.1) } + + it("Passed ST_NumPoints") { + val lineDf = sparkSession.sql("SELECT ST_GeomFromWKT('LINESTRING (0 1, 1 0, 2 0)') AS geom") + val df = lineDf.select(ST_NumPoints("geom")) + val actualResult = df.take(1)(0).getInt(0) + val expectedResult = 3 + assert(actualResult == expectedResult) + } } } diff --git a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala index 0e598e6b5f..21e897a5ee 100644 --- a/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala +++ b/sql/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala @@ -1909,4 +1909,16 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample assertEquals(expected, actual, 0.1) } } + + it("Should pass ST_NumPoints") { + val geomTestCases = Map( + ("'LINESTRING (0 1, 1 0, 2 0)'") -> "3" + ) + for (((geom), expectedResult) <- geomTestCases) { + val df = sparkSession.sql(s"SELECT ST_NumPoints(ST_GeomFromWKT($geom)), " + s"$expectedResult") + val actual = df.take(1)(0).get(0).asInstanceOf[Int] + val expected = df.take(1)(0).get(1).asInstanceOf[java.lang.Integer].intValue() + assertEquals(expected, actual) + } + } }