README.md — 15 changes: 14 additions & 1 deletion
@@ -174,7 +174,8 @@ someRuleSet.addMinMaxRules("Retail_Price_Validation", col("retail_price"), Bound
### Categorical Rules
There are two types of categorical rules which are used to validate against a pre-defined list of valid
values. As of 0.2 accepted categorical types are String, Double, Int, Long but any types outside of this can
-be input as an array() column of any type so long as it can be evaulated against the intput column
+be input as an array() column of any type so long as it can be evaluated against the input column.

```scala
val catNumerics = Array(
Rule("Valid_Stores", col("store_id"), Lookups.validStoreIDs),
@@ -187,6 +188,18 @@ Rule("Valid_Regions", col("region"), Lookups.validRegions
)
```

+An optional `ignoreCase` parameter can be specified when evaluating against a list of String values to control
+case sensitivity. By default, input columns are evaluated against the list of Strings with case sensitivity applied.
+```scala
+Rule("Valid_Regions", col("region"), Lookups.validRegions, ignoreCase=true)
+```
+
+Furthermore, the evaluation of categorical rules can be inverted by specifying `invertMatch=true` as a parameter.
+This can be handy when defining a Rule that an input column must not match a list of invalid values. For example:
+```scala
+Rule("Invalid_Skus", col("sku"), Lookups.invalidSkus, invertMatch=true)
+```

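The two parameters compose. A minimal sketch combining them (the rule name and region values below are illustrative, not part of the library):
```scala
// Passes only rows whose region, compared case-insensitively, is absent from the list
Rule("Excluded_Regions", col("region"), Array("northeast", "southeast"), ignoreCase=true, invertMatch=true)
```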
### Validation
Now that you have some rules built up... it's time to build the ruleset and validate it. As mentioned above,
the dataframe can be a simple df or a grouped df by passing column[s] to perform validation at the
demo/Example.scala — 19 changes: 10 additions & 9 deletions
@@ -50,11 +50,12 @@ object Example extends App with SparkSessionWrapper {

val catNumerics = Array(
Rule("Valid_Stores", col("store_id"), Lookups.validStoreIDs),
Rule("Valid_Skus", col("sku"), Lookups.validSkus)
Rule("Valid_Skus", col("sku"), Lookups.validSkus),
Rule("Invalid_Skus", col("sku"), Lookups.invalidSkus, invertMatch=true)
)

val catStrings = Array(
Rule("Valid_Regions", col("region"), Lookups.validRegions)
Rule("Valid_Regions", col("region"), Lookups.validRegions, ignoreCase=true)
)

//TODO - validate datetime
@@ -76,18 +77,18 @@ object Example extends App with SparkSessionWrapper {
.withColumn("create_dt", 'create_ts.cast("date"))

// Doing the validation
-// The validate method will return the rules report dataframe which breaks down which rules passed and which
-// rules failed and how/why. The second return value returns a boolean to determine whether or not all tests passed
-// val (rulesReport, passed) = RuleSet(df, Array("store_id"))
-val (rulesReport, passed) = RuleSet(df)
+// The validate method will return two reports - a complete report and a summary report.
+// The complete report is verbose and will add all rule validations to the right side of the original
+// df passed into RuleSet, while the summary report will contain all of the rows that failed one or more
+// Rule evaluations.
+val validationResults = RuleSet(df)
.add(specializedRules)
.add(minMaxPriceRules)
.add(catNumerics)
.add(catStrings)
-.validate(2)
+.validate()

-rulesReport.show(200, false)
-// rulesReport.printSchema()
+validationResults.completeReport.show(200, false)


}
demo/Rules_Engine_Examples.dbc — binary file modified (not shown)
demo/Rules_Engine_Examples.html — 49 changes: 25 additions & 24 deletions (large diff not rendered)

src/main/scala/com/databricks/labs/validation/Rule.scala — 72 changes: 64 additions & 8 deletions
@@ -1,11 +1,8 @@
package com.databricks.labs.validation

-import com.databricks.labs.validation.utils.Structures.{Bounds, ValidationException}
+import com.databricks.labs.validation.utils.Structures.Bounds
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{array, lit}
-import org.apache.spark.sql.types.BooleanType

-import java.util.UUID
-
/**
* Definition of a rule
@@ -21,6 +18,8 @@ class Rule(
private var _validNumerics: Column = array(lit(null).cast("double"))
private var _validStrings: Column = array(lit(null).cast("string"))
private var _implicitBoolean: Boolean = false
+private var _ignoreCase: Boolean = false
+private var _invertMatch: Boolean = false
val inputColumnName: String = inputColumn.expr.toString().replace("'", "")

override def toString: String = {
@@ -47,8 +46,8 @@ class Rule(
this
}

-private def setValidStrings(value: Array[String]): this.type = {
-_validStrings = lit(value)
+private def setValidStrings(value: Array[String], ignoreCase: Boolean): this.type = {
+_validStrings = if(ignoreCase) lit(value.map(_.toLowerCase)) else lit(value)
inputColumn.expr.children.map(_.prettyName)
this
}
@@ -63,6 +62,16 @@ class Rule(
this
}

+private def setIgnoreCase(value: Boolean): this.type = {
+_ignoreCase = value
+this
+}
+
+private def setInvertMatch(value: Boolean): this.type = {
+_invertMatch = value
+this
+}

def boundaries: Bounds = _boundaries

def validNumerics: Column = _validNumerics
@@ -73,6 +82,10 @@

def isImplicitBool: Boolean = _implicitBoolean

+def ignoreCase: Boolean = _ignoreCase
+
+def invertMatch: Boolean = _invertMatch

def isAgg: Boolean = {
inputColumn.expr.prettyName == "aggregateexpression" ||
inputColumn.expr.children.map(_.prettyName).contains("aggregateexpression")
@@ -114,6 +127,18 @@ object Rule {
.setValidExpr(validExpr)
}

+def apply(
+ruleName: String,
+column: Column,
+validNumerics: Array[Double],
+invertMatch: Boolean
+): Rule = {
+
+new Rule(ruleName, column, RuleType.ValidateNumerics)
+.setValidNumerics(validNumerics)
+.setInvertMatch(invertMatch)
+}

def apply(
ruleName: String,
column: Column,
@@ -122,6 +147,19 @@

new Rule(ruleName, column, RuleType.ValidateNumerics)
.setValidNumerics(validNumerics)
+.setInvertMatch(false)
}

+def apply(
+ruleName: String,
+column: Column,
+validNumerics: Array[Long],
+invertMatch: Boolean
+): Rule = {
+
+new Rule(ruleName, column, RuleType.ValidateNumerics)
+.setValidNumerics(validNumerics.map(_.toString.toDouble))
+.setInvertMatch(invertMatch)
+}

def apply(
@@ -132,6 +170,19 @@

new Rule(ruleName, column, RuleType.ValidateNumerics)
.setValidNumerics(validNumerics.map(_.toString.toDouble))
+.setInvertMatch(false)
}

+def apply(
+ruleName: String,
+column: Column,
+validNumerics: Array[Int],
+invertMatch: Boolean
+): Rule = {
+
+new Rule(ruleName, column, RuleType.ValidateNumerics)
+.setValidNumerics(validNumerics.map(_.toString.toDouble))
+.setInvertMatch(invertMatch)
+}

def apply(
@@ -142,16 +193,21 @@

new Rule(ruleName, column, RuleType.ValidateNumerics)
.setValidNumerics(validNumerics.map(_.toString.toDouble))
+.setInvertMatch(false)
}

def apply(
ruleName: String,
column: Column,
-validStrings: Array[String]
+validStrings: Array[String],
+ignoreCase: Boolean = false,
+invertMatch: Boolean = false
): Rule = {

new Rule(ruleName, column, RuleType.ValidateStrings)
-.setValidStrings(validStrings)
+.setValidStrings(validStrings, ignoreCase)
+.setIgnoreCase(ignoreCase)
+.setInvertMatch(invertMatch)
}

}
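The no-flag overloads above now delegate with `invertMatch=false`, keeping existing call sites source-compatible. A brief usage sketch of the numeric variants (rule names and ID lists are illustrative, not from the PR):
```scala
import org.apache.spark.sql.functions.col

// Default behavior: passed == true when store_id appears in the valid list (Array[Int] overload)
val allowRule = Rule("Valid_Stores", col("store_id"), Array(1001, 1002))

// Inverted: passed == true when store_id does NOT appear in the blocked list
val blockedStoreIds = Array(9001, 9002)
val blockRule = Rule("Blocked_Stores", col("store_id"), blockedStoreIds, invertMatch = true)
```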
src/main/scala/com/databricks/labs/validation/RuleSet.scala — 19 changes: 10 additions & 9 deletions
@@ -35,7 +35,7 @@ class RuleSet extends SparkSessionWrapper {

private def setGroupByCols(value: Seq[String]): this.type = {
_groupBys = value
-_isGrouped = true
+_isGrouped = value.nonEmpty
this
}

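A rule set is now flagged as grouped only when group-by columns are actually supplied. A minimal sketch of the implied behavior, assuming a DataFrame `df` and the `RuleSet` constructors used in demo/Example.scala:
```scala
val ungrouped = RuleSet(df)                    // no group-by columns => isGrouped == false
val grouped = RuleSet(df, Array("store_id"))   // one group-by column => isGrouped == true
```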
@@ -110,15 +110,16 @@ class RuleSet extends SparkSessionWrapper {
}

/**
* Merge two rule sets by adding one rule set to another
*
* @param ruleSet RuleSet to be added
* @return RuleSet
*/
def add(ruleSet: RuleSet): RuleSet = {
-new RuleSet().setDF(ruleSet.getDf)
-.setIsGrouped(ruleSet.isGrouped)
-.add(ruleSet.getRules)
+val addtnlGroupBys = ruleSet.getGroupBys diff this.getGroupBys
+val mergedGroupBys = this.getGroupBys ++ addtnlGroupBys
+this.add(ruleSet.getRules)
+.setGroupByCols(mergedGroupBys)
}

/**
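Merging now happens in place on the left-hand rule set, and the group-by columns of both sides are unioned (`diff` drops duplicates before appending). A sketch, assuming `df` and the rule arrays from demo/Example.scala:
```scala
val storeRules = RuleSet(df, Array("store_id")).add(minMaxPriceRules)
val regionRules = RuleSet(df, Array("region")).add(catStrings)

// The merged set keeps store_id and appends region: getGroupBys == Seq("store_id", "region")
val merged = storeRules.add(regionRules)
```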
src/main/scala/com/databricks/labs/validation/Validator.scala — 7 changes: 5 additions & 2 deletions
@@ -29,16 +29,19 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
rule.inputColumn.cast("string").alias("actual")
).alias(rule.ruleName)
case RuleType.ValidateNumerics =>
+val ruleExpr = if(rule.invertMatch) not(array_contains(rule.validNumerics, rule.inputColumn)) else array_contains(rule.validNumerics, rule.inputColumn)
struct(
lit(rule.ruleName).alias("ruleName"),
-array_contains(rule.validNumerics, rule.inputColumn).alias("passed"),
+ruleExpr.alias("passed"),
rule.validNumerics.cast("string").alias("permitted"),
rule.inputColumn.cast("string").alias("actual")
).alias(rule.ruleName)
case RuleType.ValidateStrings =>
+val ruleValue = if(rule.ignoreCase) lower(rule.inputColumn) else rule.inputColumn
+val ruleExpr = if(rule.invertMatch) not(array_contains(rule.validStrings, ruleValue)) else array_contains(rule.validStrings, ruleValue)
struct(
lit(rule.ruleName).alias("ruleName"),
-array_contains(rule.validStrings, rule.inputColumn).alias("passed"),
+ruleExpr.alias("passed"),
rule.validStrings.cast("string").alias("permitted"),
rule.inputColumn.cast("string").alias("actual")
).alias(rule.ruleName)
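Only the input column is lower-cased here; the valid-string list was already lower-cased in `Rule.setValidStrings`, so both sides of the comparison align. A standalone sketch of the expression the string branch builds (column name and values illustrative):
```scala
import org.apache.spark.sql.functions.{array_contains, col, lit, lower, not}

// ignoreCase: compare the lower-cased input against the already lower-cased valid list
val validStrings = lit(Array("northeast", "southeast"))
val matched = array_contains(validStrings, lower(col("region")))

// invertMatch: negate membership, so a value found in the list fails the rule
val passed = not(matched)
```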
@@ -13,7 +13,9 @@ object Lookups {

final val validRegions = Array("Northeast", "Southeast", "Midwest", "Northwest", "Southcentral", "Southwest")

-final val validSkus = Array(123456, 122987,123256, 173544, 163212, 365423, 168212)
+final val validSkus = Array(123456, 122987, 123256, 173544, 163212, 365423, 168212)

+final val invalidSkus = Array(9123456, 9122987, 9123256, 9173544, 9163212, 9365423, 9168212)

}
