Add SimilarityJoinTransform
seddonm1 committed Sep 21, 2019
1 parent 8c6ba4c commit 5949516
Showing 4 changed files with 319 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

# 2.1.0

- add `SimilarityJoinTransform`, a stage which performs a [fuzzy match](https://en.wikipedia.org/wiki/Approximate_string_matching) and can be used for dataset deduplication or approximate joins (see the example configuration below).
- add missing types `BooleanList`, `Double`, `DoubleList`, `LongList` to config reader.
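
A minimal stage configuration sketch (one entry in a job's `stages` array). The parameter names mirror the keys read by the new stage; the stage name, views and fields are illustrative only:

```json
{
  "type": "SimilarityJoinTransform",
  "name": "fuzzy match addresses",
  "environments": ["production", "test"],
  "leftView": "leftView",
  "leftFields": ["street_name", "locality_name", "postcode"],
  "rightView": "rightView",
  "rightFields": ["street", "state_postcode_suburb"],
  "outputView": "outputView",
  "threshold": 0.75,
  "shingleLength": 3,
  "numHashTables": 10
}
```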

**BREAKING**
@@ -29,9 +29,9 @@ ai.tripl.arc.transform.JSONTransform
ai.tripl.arc.transform.MetadataFilterTransform
ai.tripl.arc.transform.MetadataTransform
ai.tripl.arc.transform.MLTransform
ai.tripl.arc.transform.SimilarityJoinTransform
ai.tripl.arc.transform.SQLTransform
ai.tripl.arc.transform.TensorFlowServingTransform
ai.tripl.arc.transform.TypingTransform
ai.tripl.arc.validate.EqualityValidate
ai.tripl.arc.validate.SQLValidate

214 changes: 214 additions & 0 deletions src/main/scala/ai/tripl/arc/transform/SimilarityJoinTransform.scala
@@ -0,0 +1,214 @@
package ai.tripl.arc.transform

import java.util.UUID
import scala.collection.JavaConverters._

import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, MinHashLSH, MinHashLSHModel, NGram, RegexTokenizer}

import com.typesafe.config._

import ai.tripl.arc.api._
import ai.tripl.arc.api.API._
import ai.tripl.arc.config._
import ai.tripl.arc.config.Error._
import ai.tripl.arc.plugins.PipelineStagePlugin
import ai.tripl.arc.util.Utils
import ai.tripl.arc.util.DetailException

class SimilarityJoinTransform extends PipelineStagePlugin {

val version = Utils.getFrameworkVersion

def instantiate(index: Int, config: com.typesafe.config.Config)(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext): Either[List[ai.tripl.arc.config.Error.StageError], PipelineStage] = {
import ai.tripl.arc.config.ConfigReader._
import ai.tripl.arc.config.ConfigUtils._
implicit val c = config

val expectedKeys = "type" :: "name" :: "description" :: "environments" :: "leftView" :: "leftFields" :: "rightView" :: "rightFields" :: "outputView" :: "persist" :: "shingleLength" :: "numHashTables" :: "threshold" :: "caseSensitive" :: "numPartitions" :: "partitionBy" :: "params" :: Nil
val name = getValue[String]("name")
val description = getOptionalValue[String]("description")
val leftView = getValue[String]("leftView")
val leftFields = getValue[StringList]("leftFields")
val rightView = getValue[String]("rightView")
val rightFields = getValue[StringList]("rightFields")
val outputView = getValue[String]("outputView")
val persist = getValue[java.lang.Boolean]("persist", default = Some(false))
val shingleLength = getValue[Int]("shingleLength", default = Some(3))
val numHashTables = getValue[Int]("numHashTables", default = Some(5))
val threshold = getValue[java.lang.Double]("threshold", default = Some(0.8))
val caseSensitive = getValue[java.lang.Boolean]("caseSensitive", default = Some(false))
val partitionBy = getValue[StringList]("partitionBy", default = Some(Nil))
val numPartitions = getOptionalValue[Int]("numPartitions")
val params = readMap("params", c)
val invalidKeys = checkValidKeys(c)(expectedKeys)

(name, description, leftView, leftFields, rightView, rightFields, outputView, persist, shingleLength, numHashTables, threshold, caseSensitive, partitionBy, numPartitions, invalidKeys) match {
case (Right(name), Right(description), Right(leftView), Right(leftFields), Right(rightView), Right(rightFields), Right(outputView), Right(persist), Right(shingleLength), Right(numHashTables), Right(threshold), Right(caseSensitive), Right(partitionBy), Right(numPartitions), Right(invalidKeys)) =>

val stage = SimilarityJoinTransformStage(
plugin=this,
name=name,
description=description,
leftView=leftView,
leftFields=leftFields,
rightView=rightView,
rightFields=rightFields,
outputView=outputView,
persist=persist,
shingleLength=shingleLength,
numHashTables=numHashTables,
threshold=threshold,
caseSensitive=caseSensitive,
partitionBy=partitionBy,
numPartitions=numPartitions,
params=params
)

stage.stageDetail.put("leftView", leftView)
stage.stageDetail.put("leftFields", leftFields.asJava)
stage.stageDetail.put("rightView", rightView)
stage.stageDetail.put("rightFields", rightFields.asJava)
stage.stageDetail.put("outputView", outputView)
stage.stageDetail.put("persist", java.lang.Boolean.valueOf(persist))
stage.stageDetail.put("shingleLength", java.lang.Integer.valueOf(shingleLength))
stage.stageDetail.put("numHashTables", java.lang.Integer.valueOf(numHashTables))
stage.stageDetail.put("threshold", java.lang.Double.valueOf(threshold))
stage.stageDetail.put("caseSensitive", java.lang.Boolean.valueOf(caseSensitive))
stage.stageDetail.put("partitionBy", partitionBy.asJava)

Right(stage)
case _ =>
val allErrors: Errors = List(name, description, leftView, leftFields, rightView, rightFields, outputView, persist, shingleLength, numHashTables, threshold, caseSensitive, partitionBy, numPartitions, invalidKeys).collect{ case Left(errs) => errs }.flatten
val stageName = stringOrDefault(name, "unnamed stage")
val err = StageError(index, stageName, c.origin.lineNumber, allErrors)
Left(err :: Nil)
}
}
}

case class SimilarityJoinTransformStage(
plugin: SimilarityJoinTransform,
name: String,
description: Option[String],
leftView: String,
leftFields: List[String],
rightView: String,
rightFields: List[String],
outputView: String,
persist: Boolean,
shingleLength: Int,
numHashTables: Int,
threshold: Double,
caseSensitive: Boolean,
partitionBy: List[String],
numPartitions: Option[Int],
params: Map[String, String]
) extends PipelineStage {

override def execute()(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext): Option[DataFrame] = {
SimilarityJoinTransformStage.execute(this)
}
}

object SimilarityJoinTransformStage {

def execute(stage: SimilarityJoinTransformStage)(implicit spark: SparkSession, logger: ai.tripl.arc.util.log.logger.Logger, arcContext: ARCContext): Option[DataFrame] = {

val uuid = UUID.randomUUID.toString
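
// the transform is built as a Spark ML pipeline: the concatenated input string is
// split into characters, grouped into character n-gram shingles, converted to
// term-frequency vectors and hashed with MinHash LSH, which allows the fitted model
// to approximate the Jaccard similarity between rows of the two views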

// split input string into individual characters
val regexTokenizer = { new RegexTokenizer()
.setInputCol(uuid)
.setPattern("")
.setMinTokenLength(1)
.setToLowercase(!stage.caseSensitive)
}

// group the characters into shingles (n-grams) of length shingleLength
val nGram = { new NGram()
.setInputCol(regexTokenizer.getOutputCol)
.setN(stage.shingleLength)
}

// convert the shingles to term-frequency vectors
val countVectorizer = { new CountVectorizer()
.setInputCol(nGram.getOutputCol)
}

// build locality-sensitive hashing model
val minHashLSH = { new MinHashLSH()
.setInputCol(countVectorizer.getOutputCol)
.setNumHashTables(stage.numHashTables)
.setOutputCol("lsh")
}

val pipeline = new Pipeline().setStages(Array(regexTokenizer, nGram, countVectorizer, minHashLSH))
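// note: the stage order matters as the fitted MinHashLSHModel is later retrieved by index (stages(3))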

val transformedDF = try {

val leftView = spark.table(stage.leftView)
val rightView = spark.table(stage.rightView)

// concatenate the selected fields of each view into a single space-separated comparison string; null fields contribute nothing
val inputLeftView = leftView.select(
col("*"), trim(concat(stage.leftFields.map{ field => when(col(field).isNotNull, concat(col(field).cast(StringType), lit(" "))).otherwise("") }:_*)).alias(uuid)
)
val inputRightView = rightView.select(
col("*"), trim(concat(stage.rightFields.map{ field => when(col(field).isNotNull, concat(col(field).cast(StringType), lit(" "))).otherwise("") }:_*)).alias(uuid)
)

val pipelineModel = pipeline.fit(inputLeftView)

val datasetA = pipelineModel.transform(inputLeftView)
val datasetB = pipelineModel.transform(inputRightView)

val leftOutputColumns = leftView.columns.map{columnName => col(s"datasetA.${columnName}")}
val rightOutputColumns = rightView.columns.map{columnName => col(s"datasetB.${columnName}")}

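// approxSimilarityJoin filters on Jaccard *distance*, so the user-facing similarity
// threshold is converted with (1.0 - threshold) and the returned distCol is turned
// back into a similarity column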
pipelineModel.stages(3).asInstanceOf[MinHashLSHModel]
.approxSimilarityJoin(datasetA, datasetB, (1.0-stage.threshold))
.select((leftOutputColumns ++ rightOutputColumns ++ List((lit(1.0)-col("distCol")).alias("similarity"))):_*)

} catch {
case e: Exception => throw new Exception(e) with DetailException {
override val detail = stage.stageDetail
}
}

// repartition to distribute rows evenly
val repartitionedDF = stage.partitionBy match {
case Nil => {
stage.numPartitions match {
case Some(numPartitions) => transformedDF.repartition(numPartitions)
case None => transformedDF
}
}
case partitionBy => {
// create a column array for repartitioning
val partitionCols = partitionBy.map(col => transformedDF(col))
stage.numPartitions match {
case Some(numPartitions) => transformedDF.repartition(numPartitions, partitionCols:_*)
case None => transformedDF.repartition(partitionCols:_*)
}
}
}
if (arcContext.immutableViews) repartitionedDF.createTempView(stage.outputView) else repartitionedDF.createOrReplaceTempView(stage.outputView)

if (!repartitionedDF.isStreaming) {
// add output column and partition detail to logs
stage.stageDetail.put("outputColumns", java.lang.Integer.valueOf(repartitionedDF.schema.length))
stage.stageDetail.put("numPartitions", java.lang.Integer.valueOf(repartitionedDF.rdd.partitions.length))

if (stage.persist) {
repartitionedDF.persist(arcContext.storageLevel)
stage.stageDetail.put("records", java.lang.Long.valueOf(repartitionedDF.count))
}
}

Option(repartitionedDF)
}
}
@@ -0,0 +1,103 @@
package ai.tripl.arc.plugins

import ai.tripl.arc.ARC
import ai.tripl.arc.api.API._
import ai.tripl.arc.config.ArcPipeline
import ai.tripl.arc.config.Error._
import ai.tripl.arc.util.TestUtils

import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FunSuite}
import ai.tripl.arc.extract.ParquetExtract

class SimilarityJoinTransformSuite extends FunSuite with BeforeAndAfter {

var session: SparkSession = _
val leftView = "leftView"
val rightView = "rightView"
val outputView = "outputView"

before {
implicit val spark = SparkSession
.builder()
.master("local[*]")
.config("spark.ui.port", "9999")
.appName("Spark ETL Test")
.getOrCreate()
spark.sparkContext.setLogLevel("INFO")
implicit val logger = TestUtils.getLogger()

// set for deterministic timezone
spark.conf.set("spark.sql.session.timeZone", "UTC")

session = spark
import spark.implicits._
}

after {
session.stop()
}

test("SimilarityJoinTransformSuite") {
implicit val spark = session
import spark.implicits._
implicit val logger = TestUtils.getLogger()
implicit val arcContext = TestUtils.getARCContext(isStreaming=false)

val leftDF = Seq(
("GANSW705647478",Option("UNIT 3"),59,"INVERNESS","AVENUE","PENSHURST",2222,"NSW"),
("GANSW704384670",Option("UNIT 10"),30,"ARCHER","STREET","CHATSWOOD",2067,"NSW"),
("GANSW716607633",Option("UNIT 1"),95,"GARDINER","ROAD","ORANGE",2800,"NSW"),
("GANSW704527834",None,26,"LINKS","AVENUE","CRONULLA",2230,"NSW"),
("GANSW704579026",None,13,"VALLEY","ROAD","DENHAMS BEACH",2536,"NSW"),
("GANSW712760955",Option("UNIT 17"),39,"MACARTHUR","STREET","GRIFFITH",2680,"NSW"),
("GANSW704356027",None,66,"MILLERS","ROAD","CATTAI",2756,"NSW"),
("GANSW705978672",None,74,"CANYON","DRIVE","STANHOPE GARDENS",2768,"NSW"),
("GANSW717662718",Option("UNIT 744"),9,"ROTHSCHILD","AVENUE","ROSEBERY",2018,"NSW"),
("GANSW710590397",Option("UNIT 303"),2,"DIND","STREET","MILSONS POINT",2061,"NSW")
).toDF("gnaf_pid", "flat_number", "number_first", "street_name", "street_type", "locality_name", "postcode", "state")
leftDF.createOrReplaceTempView(leftView)

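// the right view contains abbreviated and misspelt versions of two left view addresses
// ("DR", "GDNS.", "STANEHOPE"); only one pair is expected to match at the 0.75 threshold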
val rightDF = Seq(
(0L,"U3 59 INVERNESS AVENUE","NSW 2222 PENSHURST"),
(1L,"74 CANYON DR", "NSW 2768 STANEHOPE GDNS."),
).toDF("id", "street", "state_postcode_suburb")
rightDF.createOrReplaceTempView(rightView)

val conf = s"""{
"stages": [
{
"type": "SimilarityJoinTransform",
"name": "test",
"description": "test",
"environments": [
"production",
"test"
],
"leftView": "${leftView}",
"leftFields": ["flat_number", "number_first", "street_name", "street_type", "locality_name", "postcode", "state"],
"rightView": "${rightView}",
"rightFields": ["street", "state_postcode_suburb"],
"outputView": "${outputView}",
"threshold": 0.75,
"shingleLength": 3,
"numHashTables": 10
}
]
}"""

val pipelineEither = ArcPipeline.parseConfig(Left(conf), arcContext)

pipelineEither match {
case Left(_) => {
println(pipelineEither)
assert(false)
}
case Right((pipeline, _)) => {
val df = ARC.run(pipeline)(spark, logger, arcContext)
assert(df.get.count == 1)
}
}

}
}
