[SEDONA-439] Add RS_Union_Aggr (#1140)
furqaankhan authored Dec 8, 2023
1 parent fa76ed5 commit 691a1fa
Showing 7 changed files with 223 additions and 3 deletions.
@@ -1,3 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sedona.common.raster;

import org.locationtech.jts.geom.Geometry;
42 changes: 42 additions & 0 deletions docs/api/sql/Raster-aggregate-function.md
@@ -0,0 +1,42 @@
## RS_Union_Aggr

Introduction: Returns a multi-band raster assembled from all rasters in the provided column. The first band of each input raster is extracted and placed in the output raster at the band position given by its index value.

!!!Note
RS_Union_Aggr can take multi-banded rasters as input, but it only extracts the first band of each raster into the resulting raster. RS_Union_Aggr expects the following of its input and throws an IllegalArgumentException if any condition is not met:

- Indexes must be in an arithmetic sequence without any gaps.
- Indexes must be unique and not repeated.
- All rasters must have the same shape.

Format: `RS_Union_Aggr(A: rasterColumn, B: indexColumn)`

Since: `v1.5.1`

Spark SQL Example:

Contents of `raster_table`:

```
+------------------------------+-----+
| raster|index|
+------------------------------+-----+
|GridCoverage2D["geotiff_cov...| 1|
|GridCoverage2D["geotiff_cov...| 2|
|GridCoverage2D["geotiff_cov...| 3|
|GridCoverage2D["geotiff_cov...| 4|
|GridCoverage2D["geotiff_cov...| 5|
+------------------------------+-----+
```

```
SELECT RS_Union_Aggr(raster, index) FROM raster_table
```

Output:

The output raster contains the first band of each raster in `raster_table`, placed at the band position given by its index.

```
GridCoverage2D["geotiff_coverage", GeneralEnvel...
```
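A minimal end-to-end sketch in Scala follows; the GeoTiff path, the `sparkSession` variable, and the column names are illustrative assumptions. The index column is generated with `row_number` over a window, mirroring the test added in this commit.

```
// Load GeoTiff files and assign a 1-based index to each raster.
// The path below is a placeholder; `sparkSession` must have Sedona functions registered.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{expr, row_number}

val rasterDf = sparkSession.read.format("binaryFile")
  .load("/data/rasters/*.tiff")
  .withColumn("raster", expr("RS_FromGeoTiff(content)"))
  .withColumn("index", row_number().over(Window.orderBy("path")))
  .select("raster", "index")

rasterDf.createOrReplaceTempView("raster_table")

// The first band of each input raster becomes band `index` of the single output raster.
val unioned = sparkSession.sql("SELECT RS_Union_Aggr(raster, index) AS raster FROM raster_table")
```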
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -64,6 +64,7 @@ nav:
- SedonaKepler: api/sql/Visualization_SedonaKepler.md
- Raster data:
- Raster loader: api/sql/Raster-loader.md
- Raster aggregates: api/sql/Raster-aggregate-function.md
- Raster writer: api/sql/Raster-writer.md
- Raster operators: api/sql/Raster-operators.md
- Raster map algebra: api/sql/Raster-map-algebra.md
@@ -22,12 +22,14 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.sedona_sql.expressions.collect.{ST_Collect}
import org.apache.spark.sql.sedona_sql.expressions.collect.ST_Collect
import org.apache.spark.sql.sedona_sql.expressions.raster._
import org.apache.spark.sql.sedona_sql.expressions._
import org.geotools.coverage.grid.GridCoverage2D
import org.locationtech.jts.geom.Geometry
import org.locationtech.jts.operation.buffer.BufferParameters

import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

object Catalog {
@@ -261,6 +263,8 @@ object Catalog {
function[RS_NetCDFInfo]()
)

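// Raster aggregate function; registered separately in UdfRegistrator because its Aggregator type parameters differ from the geometry aggregators below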
val rasterAggregateExpression: Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] = new RS_Union_Aggr

val aggregateExpressions: Seq[Aggregator[Geometry, Geometry, Geometry]] = Seq(
new ST_Union_Aggr,
new ST_Envelope_Aggr,
@@ -37,6 +37,7 @@ object UdfRegistrator {
}
Catalog.aggregateExpressions.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, functions.udaf(f))) // SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.udf.register(f.getClass.getSimpleName, f)) // SPARK2 anchor
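// The raster aggregator has different Aggregator type parameters than the geometry aggregators above, so it is registered individually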
sparkSession.udf.register(Catalog.rasterAggregateExpression.getClass.getSimpleName, functions.udaf(Catalog.rasterAggregateExpression))
}

def dropAll(sparkSession: SparkSession): Unit = {
@@ -45,5 +46,6 @@ Catalog.aggregateExpressions.foreach(f => sparkSession.udf.register(f.getClass.g
}
Catalog.aggregateExpressions.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName))) // SPARK3 anchor
//Catalog.aggregateExpressions_UDAF.foreach(f => sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(f.getClass.getSimpleName))) // SPARK2 anchor
sparkSession.sessionState.functionRegistry.dropFunction(FunctionIdentifier(Catalog.rasterAggregateExpression.getClass.getSimpleName))
}
}
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.sedona_sql.expressions.raster

import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
import org.apache.sedona.common.utils.RasterUtils
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.geotools.coverage.GridSampleDimension
import org.geotools.coverage.grid.GridCoverage2D

import java.awt.image.WritableRaster
import javax.media.jai.RasterFactory
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

case class BandData(var bandInt: Array[Int], var bandDouble: Array[Double], var index: Int, var isIntegral: Boolean)

/**
* Return a raster containing bands at given indexes from all rasters in a given column
*/
class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] {

var width: Int = -1

var height: Int = -1

var referenceRaster: GridCoverage2D = _

var gridSampleDimension: mutable.Map[Int, GridSampleDimension] = new mutable.HashMap()

def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()

/**
* Validate that the raster shape is the same across the given column
*/
def checkRasterShape(raster: GridCoverage2D): Boolean = {
// first iteration
if (width == -1 && height == -1) {
width = RasterAccessors.getWidth(raster)
height = RasterAccessors.getHeight(raster)
referenceRaster = raster
true
} else {
val widthNewRaster = RasterAccessors.getWidth(raster)
val heightNewRaster = RasterAccessors.getHeight(raster)

width == widthNewRaster && height == heightNewRaster
}
}

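/**
* Validate the raster's shape and index, extract its first band, and append it to the buffer
*/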
def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
val raster = input._1
if (!checkRasterShape(raster)) {
throw new IllegalArgumentException("Rasters provides should be of the same shape.")
}
if (gridSampleDimension.contains(input._2)) {
throw new IllegalArgumentException("Indexes shouldn't be repeated. Index should be in an arithmetic sequence.")
}

val rasterData = RasterUtils.getRaster(raster.getRenderedImage)
val isIntegral = RasterUtils.isDataTypeIntegral(rasterData.getDataBuffer.getDataType)

val bandData = if (isIntegral) {
val band = rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Int]])
BandData(band, null, input._2, isIntegral)
} else {
val band = rasterData.getSamples(0, 0, width, height, 0, null.asInstanceOf[Array[Double]])
BandData(null, band, input._2, isIntegral)
}
gridSampleDimension = gridSampleDimension + (input._2 -> raster.getSampleDimension(0))

buffer += bandData
}

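/**
* Concatenate partial buffers produced on different partitions
*/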
def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
ArrayBuffer.concat(buffer1, buffer2)
}

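/**
* Sort the collected bands by index, verify the indexes form the sequence 1..N, and assemble the multi-band output raster
*/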
def finish(merged: ArrayBuffer[BandData]): GridCoverage2D = {
val sortedMerged = merged.sortBy(_.index)
val numBands = sortedMerged.length
val rasterData = RasterUtils.getRaster(referenceRaster.getRenderedImage)
val dataTypeCode = rasterData.getDataBuffer.getDataType
val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, numBands, null)
val gridSampleDimensions: Array[GridSampleDimension] = new Array[GridSampleDimension](numBands)
var indexCheck = 1

for (bandData: BandData <- sortedMerged) {
if (bandData.index != indexCheck) {
throw new IllegalArgumentException("Indexes should be in a valid arithmetic sequence.")
}
indexCheck += 1
gridSampleDimensions(bandData.index - 1) = gridSampleDimension(bandData.index)
if (RasterUtils.isDataTypeIntegral(dataTypeCode))
resultRaster.setSamples(0, 0, width, height, bandData.index - 1, bandData.bandInt)
else
resultRaster.setSamples(0, 0, width, height, bandData.index - 1, bandData.bandDouble)
}
val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, true)
}

val serde = ExpressionEncoder[GridCoverage2D]

val bufferSerde = ExpressionEncoder[ArrayBuffer[BandData]]

def outputEncoder: ExpressionEncoder[GridCoverage2D] = serde

def bufferEncoder: Encoder[ArrayBuffer[BandData]] = bufferSerde
}
@@ -20,11 +20,12 @@ package org.apache.sedona.sql

import org.apache.sedona.common.raster.MapAlgebra
import org.apache.sedona.common.utils.RasterUtils
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.functions.{col, collect_list, expr}
import org.apache.spark.sql.functions.{col, collect_list, expr, row_number}
import org.geotools.coverage.grid.GridCoverage2D
import org.junit.Assert.{assertEquals, assertNull, assertTrue}
import org.locationtech.jts.geom.{Coordinate, Geometry, Point}
import org.locationtech.jts.geom.{Coordinate, Geometry}
import org.scalatest.{BeforeAndAfter, GivenWhenThen}

import java.awt.image.DataBuffer
@@ -935,6 +936,28 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assertEquals(expected, actual)
}

it("Passed RS_Union_Aggr") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")
.union(sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")
.union(sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")))
.withColumn("raster", expr("RS_FromGeoTiff(content) as raster"))
.withColumn("index", row_number().over(Window.orderBy("raster")))
.selectExpr("raster", "index")

val dfTest = sparkSession.read.format("binaryFile").load(resourceFolder + "raster/test1.tiff")
.selectExpr("RS_FromGeoTiff(content) as raster")

df = df.selectExpr("RS_Union_aggr(raster, index) as rasters")

val actualBands = df.selectExpr("RS_NumBands(rasters)").first().get(0)
val expectedBands = 3
assertEquals(expectedBands, actualBands)

val actualMetadata = df.selectExpr("RS_Metadata(rasters)").first().getSeq(0).slice(0, 9)
val expectedMetadata = dfTest.selectExpr("RS_Metadata(raster)").first().getSeq(0).slice(0, 9)
assertTrue(expectedMetadata.equals(actualMetadata))
}

it("Passed RS_ZonalStats") {
var df = sparkSession.read.format("binaryFile").load(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif")
df = df.selectExpr("RS_FromGeoTiff(content) as raster", "ST_GeomFromWKT('POLYGON ((236722 4204770, 243900 4204770, 243900 4197590, 221170 4197590, 236722 4204770))', 26918) as geom")