Showing 4 changed files with 319 additions and 1 deletion.
214 changes: 214 additions & 0 deletions
src/main/scala/ai/tripl/arc/transform/SimilarityJoinTransform.scala
@@ -0,0 +1,214 @@
package ai.tripl.arc.transform

import java.util.UUID
import scala.collection.JavaConverters._

import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, MinHashLSH, MinHashLSHModel, NGram, RegexTokenizer}

import com.typesafe.config._

import ai.tripl.arc.api._
import ai.tripl.arc.api.API._
import ai.tripl.arc.config._
import ai.tripl.arc.config.Error._
import ai.tripl.arc.plugins.PipelineStagePlugin
import ai.tripl.arc.util.Utils
import ai.tripl.arc.util.DetailException

class SimilarityJoinTransform extends PipelineStagePlugin {

  val version = Utils.getFrameworkVersion

  def instantiate(index: Int, config: com.typesafe.config.Config)(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext): Either[List[ai.tripl.arc.config.Error.StageError], PipelineStage] = {
    import ai.tripl.arc.config.ConfigReader._
    import ai.tripl.arc.config.ConfigUtils._
    implicit val c = config

    val expectedKeys = "type" :: "name" :: "description" :: "environments" :: "leftView" :: "leftFields" :: "rightView" :: "rightFields" :: "outputView" :: "persist" :: "shingleLength" :: "numHashTables" :: "threshold" :: "caseSensitive" :: "numPartitions" :: "partitionBy" :: "params" :: Nil
    val name = getValue[String]("name")
    val description = getOptionalValue[String]("description")
    val leftView = getValue[String]("leftView")
    val leftFields = getValue[StringList]("leftFields")
    val rightView = getValue[String]("rightView")
    val rightFields = getValue[StringList]("rightFields")
    val outputView = getValue[String]("outputView")
    val persist = getValue[java.lang.Boolean]("persist", default = Some(false))
    val shingleLength = getValue[Int]("shingleLength", default = Some(3))
    val numHashTables = getValue[Int]("numHashTables", default = Some(5))
    val threshold = getValue[java.lang.Double]("threshold", default = Some(0.8))
    val caseSensitive = getValue[java.lang.Boolean]("caseSensitive", default = Some(false))
    val partitionBy = getValue[StringList]("partitionBy", default = Some(Nil))
    val numPartitions = getOptionalValue[Int]("numPartitions")
    val params = readMap("params", c)
    val invalidKeys = checkValidKeys(c)(expectedKeys)

    (name, description, leftView, leftFields, rightView, rightFields, outputView, persist, shingleLength, numHashTables, threshold, caseSensitive, partitionBy, numPartitions, invalidKeys) match {
      case (Right(name), Right(description), Right(leftView), Right(leftFields), Right(rightView), Right(rightFields), Right(outputView), Right(persist), Right(shingleLength), Right(numHashTables), Right(threshold), Right(caseSensitive), Right(partitionBy), Right(numPartitions), Right(invalidKeys)) =>

        val stage = SimilarityJoinTransformStage(
          plugin=this,
          name=name,
          description=description,
          leftView=leftView,
          leftFields=leftFields,
          rightView=rightView,
          rightFields=rightFields,
          outputView=outputView,
          persist=persist,
          shingleLength=shingleLength,
          numHashTables=numHashTables,
          threshold=threshold,
          caseSensitive=caseSensitive,
          partitionBy=partitionBy,
          numPartitions=numPartitions,
          params=params
        )

        stage.stageDetail.put("leftView", leftView)
        stage.stageDetail.put("leftFields", leftFields.asJava)
        stage.stageDetail.put("rightView", rightView)
        stage.stageDetail.put("rightFields", rightFields.asJava)
        stage.stageDetail.put("outputView", outputView)
        stage.stageDetail.put("persist", java.lang.Boolean.valueOf(persist))
        stage.stageDetail.put("shingleLength", java.lang.Integer.valueOf(shingleLength))
        stage.stageDetail.put("numHashTables", java.lang.Integer.valueOf(numHashTables))
        stage.stageDetail.put("threshold", java.lang.Double.valueOf(threshold))
        stage.stageDetail.put("caseSensitive", java.lang.Boolean.valueOf(caseSensitive))
        stage.stageDetail.put("partitionBy", partitionBy.asJava)

        Right(stage)
      case _ =>
        val allErrors: Errors = List(name, description, leftView, leftFields, rightView, rightFields, outputView, persist, shingleLength, numHashTables, threshold, caseSensitive, partitionBy, numPartitions, invalidKeys).collect{ case Left(errs) => errs }.flatten
        val stageName = stringOrDefault(name, "unnamed stage")
        val err = StageError(index, stageName, c.origin.lineNumber, allErrors)
        Left(err :: Nil)
    }
  }
}

case class SimilarityJoinTransformStage(
  plugin: SimilarityJoinTransform,
  name: String,
  description: Option[String],
  leftView: String,
  leftFields: List[String],
  rightView: String,
  rightFields: List[String],
  outputView: String,
  persist: Boolean,
  shingleLength: Int,
  numHashTables: Int,
  threshold: Double,
  caseSensitive: Boolean,
  partitionBy: List[String],
  numPartitions: Option[Int],
  params: Map[String, String]
) extends PipelineStage {

  override def execute()(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext): Option[DataFrame] = {
    SimilarityJoinTransformStage.execute(this)
  }
}

object SimilarityJoinTransformStage {

  def execute(stage: SimilarityJoinTransformStage)(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext): Option[DataFrame] = {

    val uuid = UUID.randomUUID.toString

    // split input string into individual characters
    val regexTokenizer = { new RegexTokenizer()
      .setInputCol(uuid)
      .setPattern("")
      .setMinTokenLength(1)
      .setToLowercase(!stage.caseSensitive)
    }

    // produce ngrams to group the characters
    val nGram = { new NGram()
      .setInputCol(regexTokenizer.getOutputCol)
      .setN(stage.shingleLength)
    }

    // convert to vector
    val countVectorizer = { new CountVectorizer()
      .setInputCol(nGram.getOutputCol)
    }

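    // MinHash LSH approximates the Jaccard distance between the shingle-count
    // vectors: rows become candidate pairs when they collide in any one of the
    // numHashTables hash tables, so more tables raises recall at the cost of
    // extra computation and shuffle size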
    // build locality-sensitive hashing model
    val minHashLSH = { new MinHashLSH()
      .setInputCol(countVectorizer.getOutputCol)
      .setNumHashTables(stage.numHashTables)
      .setOutputCol("lsh")
    }

    val pipeline = new Pipeline().setStages(Array(regexTokenizer, nGram, countVectorizer, minHashLSH))

    val transformedDF = try {

      val leftView = spark.table(stage.leftView)
      val rightView = spark.table(stage.rightView)

      // create a string space concatenated field
      val inputLeftView = leftView.select(
        col("*"), trim(concat(stage.leftFields.map{ field => when(col(field).isNotNull, concat(col(field).cast(StringType), lit(" "))).otherwise("") }:_*)).alias(uuid)
      )
      val inputRightView = rightView.select(
        col("*"), trim(concat(stage.rightFields.map{ field => when(col(field).isNotNull, concat(col(field).cast(StringType), lit(" "))).otherwise("") }:_*)).alias(uuid)
      )

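      // the pipeline is fit against the left view only, so the CountVectorizer
      // vocabulary (and therefore the MinHash signatures) is derived from the
      // left dataset and reused when transforming the right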
      val pipelineModel = pipeline.fit(inputLeftView)

      val datasetA = pipelineModel.transform(inputLeftView)
      val datasetB = pipelineModel.transform(inputRightView)

      val leftOutputColumns = leftView.columns.map{columnName => col(s"datasetA.${columnName}")}
      val rightOutputColumns = rightView.columns.map{columnName => col(s"datasetB.${columnName}")}

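      // approxSimilarityJoin takes a Jaccard distance cutoff, while this stage is
      // configured with a similarity threshold, hence the 1.0 - threshold
      // conversion; similarity is restored in the output as 1.0 - distCol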
      pipelineModel.stages(3).asInstanceOf[MinHashLSHModel]
        .approxSimilarityJoin(datasetA, datasetB, (1.0-stage.threshold))
        .select((leftOutputColumns ++ rightOutputColumns ++ List((lit(1.0)-col("distCol")).alias("similarity"))):_*)

    } catch {
      case e: Exception => throw new Exception(e) with DetailException {
        override val detail = stage.stageDetail
      }
    }

    // repartition to distribute rows evenly
    val repartitionedDF = stage.partitionBy match {
      case Nil => {
        stage.numPartitions match {
          case Some(numPartitions) => transformedDF.repartition(numPartitions)
          case None => transformedDF
        }
      }
      case partitionBy => {
        // create a column array for repartitioning
        val partitionCols = partitionBy.map(col => transformedDF(col))
        stage.numPartitions match {
          case Some(numPartitions) => transformedDF.repartition(numPartitions, partitionCols:_*)
          case None => transformedDF.repartition(partitionCols:_*)
        }
      }
    }
    if (arcContext.immutableViews) repartitionedDF.createTempView(stage.outputView) else repartitionedDF.createOrReplaceTempView(stage.outputView)

    if (!repartitionedDF.isStreaming) {
      // add partition and predicate pushdown detail to logs
      stage.stageDetail.put("outputColumns", java.lang.Integer.valueOf(repartitionedDF.schema.length))
      stage.stageDetail.put("numPartitions", java.lang.Integer.valueOf(repartitionedDF.rdd.partitions.length))

      if (stage.persist) {
        repartitionedDF.persist(arcContext.storageLevel)
        stage.stageDetail.put("records", java.lang.Long.valueOf(repartitionedDF.count))
      }
    }

    Option(repartitionedDF)
  }
}
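For orientation, the following is a minimal, self-contained sketch of the same character-shingle plus MinHash LSH technique outside Arc. It is illustrative only: the object name SimilarityJoinDemo and the toy data are assumptions, not part of this commit, and it presumes a Spark 2.4.x build with spark-mllib on the classpath.

// hypothetical standalone demo mirroring the four pipeline stages above
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, MinHashLSH, MinHashLSHModel, NGram, RegexTokenizer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object SimilarityJoinDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("similarity-join-demo").getOrCreate()
    import spark.implicits._

    val left = Seq("59 INVERNESS AVENUE PENSHURST", "74 CANYON DRIVE STANHOPE GARDENS").toDF("text")
    val right = Seq("U3 59 INVERNESS AVE PENSHURST", "74 CANYON DR STANHOPE GDNS").toDF("text")

    // identical stages to the transform: characters -> trigrams -> counts -> MinHash
    val tokenizer = new RegexTokenizer().setInputCol("text").setPattern("").setMinTokenLength(1).setToLowercase(true)
    val ngram = new NGram().setInputCol(tokenizer.getOutputCol).setN(3)
    val vectorizer = new CountVectorizer().setInputCol(ngram.getOutputCol)
    val lsh = new MinHashLSH().setInputCol(vectorizer.getOutputCol).setNumHashTables(5).setOutputCol("lsh")

    val model = new Pipeline().setStages(Array(tokenizer, ngram, vectorizer, lsh)).fit(left)

    // join with a loose Jaccard distance cutoff of 0.5, i.e. similarity >= 0.5
    model.stages(3).asInstanceOf[MinHashLSHModel]
      .approxSimilarityJoin(model.transform(left), model.transform(right), 0.5)
      .select(col("datasetA.text"), col("datasetB.text"), (lit(1.0) - col("distCol")).alias("similarity"))
      .show(false)

    spark.stop()
  }
}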
103 changes: 103 additions & 0 deletions
src/test/scala/ai/tripl/arc/transform/SimilarityJoinTransformSuite.scala
@@ -0,0 +1,103 @@
package ai.tripl.arc.plugins

import ai.tripl.arc.ARC
import ai.tripl.arc.api.API._
import ai.tripl.arc.config.ArcPipeline
import ai.tripl.arc.config.Error._
import ai.tripl.arc.util.TestUtils

import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FunSuite}

class SimilarityJoinTransformSuite extends FunSuite with BeforeAndAfter {

  var session: SparkSession = _
  val leftView = "leftView"
  val rightView = "rightView"
  val outputView = "outputView"

  before {
    implicit val spark = SparkSession
      .builder()
      .master("local[*]")
      .config("spark.ui.port", "9999")
      .appName("Spark ETL Test")
      .getOrCreate()
    spark.sparkContext.setLogLevel("INFO")
    implicit val logger = TestUtils.getLogger()

    // set for deterministic timezone
    spark.conf.set("spark.sql.session.timeZone", "UTC")

    session = spark
    import spark.implicits._
  }

  after {
    session.stop()
  }

  test("SimilarityJoinTransformSuite") {
    implicit val spark = session
    import spark.implicits._
    implicit val logger = TestUtils.getLogger()
    implicit val arcContext = TestUtils.getARCContext(isStreaming=false)

    val leftDF = Seq(
      ("GANSW705647478",Option("UNIT 3"),59,"INVERNESS","AVENUE","PENSHURST",2222,"NSW"),
      ("GANSW704384670",Option("UNIT 10"),30,"ARCHER","STREET","CHATSWOOD",2067,"NSW"),
      ("GANSW716607633",Option("UNIT 1"),95,"GARDINER","ROAD","ORANGE",2800,"NSW"),
      ("GANSW704527834",None,26,"LINKS","AVENUE","CRONULLA",2230,"NSW"),
      ("GANSW704579026",None,13,"VALLEY","ROAD","DENHAMS BEACH",2536,"NSW"),
      ("GANSW712760955",Option("UNIT 17"),39,"MACARTHUR","STREET","GRIFFITH",2680,"NSW"),
      ("GANSW704356027",None,66,"MILLERS","ROAD","CATTAI",2756,"NSW"),
      ("GANSW705978672",None,74,"CANYON","DRIVE","STANHOPE GARDENS",2768,"NSW"),
      ("GANSW717662718",Option("UNIT 744"),9,"ROTHSCHILD","AVENUE","ROSEBERY",2018,"NSW"),
      ("GANSW710590397",Option("UNIT 303"),2,"DIND","STREET","MILSONS POINT",2061,"NSW")
    ).toDF("gnaf_pid", "flat_number", "number_first", "street_name", "street_type", "locality_name", "postcode", "state")
    leftDF.createOrReplaceTempView(leftView)

    // deliberately abbreviated and misspelled addresses to exercise the fuzzy match
    val rightDF = Seq(
      (0L,"U3 59 INVERNESS AVENUE","NSW 2222 PENSHURST"),
      (1L,"74 CANYON DR", "NSW 2768 STANEHOPE GDNS.")
    ).toDF("id", "street", "state_postcode_suburb")
    rightDF.createOrReplaceTempView(rightView)

    val conf = s"""{
      "stages": [
        {
          "type": "SimilarityJoinTransform",
          "name": "test",
          "description": "test",
          "environments": [
            "production",
            "test"
          ],
          "leftView": "${leftView}",
          "leftFields": ["flat_number", "number_first", "street_name", "street_type", "locality_name", "postcode", "state"],
          "rightView": "${rightView}",
          "rightFields": ["street", "state_postcode_suburb"],
          "outputView": "${outputView}",
          "threshold": 0.75,
          "shingleLength": 3,
          "numHashTables": 10
        }
      ]
    }"""

    val pipelineEither = ArcPipeline.parseConfig(Left(conf), arcContext)

    pipelineEither match {
      case Left(_) => {
        println(pipelineEither)
        assert(false)
      }
      case Right((pipeline, _)) => {
        val df = ARC.run(pipeline)(spark, logger, arcContext)
        // only the first right-hand record clears the 0.75 similarity threshold
        assert(df.get.count == 1)
      }
    }

  }
}
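Once ARC.run completes, the stage registers its result as the temporary view named by outputView, so the matches can be inspected directly. A hypothetical follow-up query (column names as defined by the test views above):

spark.sql(s"SELECT gnaf_pid, id, similarity FROM ${outputView} ORDER BY similarity DESC").show(false)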