
Commit

initial commit
MrPowers committed Nov 26, 2022
0 parents commit a266c7c
Showing 6 changed files with 326 additions and 0 deletions.
Empty file added README.md
56 changes: 56 additions & 0 deletions mack/__init__.py
@@ -0,0 +1,56 @@
from delta import DeltaTable
from pyspark.sql import SparkSession
from pyspark.sql.functions import expr


def type_2_scd_upsert(path, updates_df, primaryKey, attrColNames):
    # convenience wrapper that assumes the conventional SCD2 column names
    return type_2_scd_generic_upsert(
        path, updates_df, primaryKey, attrColNames, "is_current", "effective_time", "end_time"
    )


def type_2_scd_generic_upsert(path, updates_df, primaryKey, attrColNames, isCurrentColName, effectiveTimeColName, endTimeColName):
    baseTable = DeltaTable.forPath(SparkSession.getActiveSession(), path)

    # validate the existing Delta table
    # TODO: move the validation logic to a separate abstraction
    baseColNames = baseTable.toDF().columns
    requiredBaseColNames = [primaryKey] + attrColNames + [isCurrentColName, effectiveTimeColName, endTimeColName]
    if sorted(baseColNames) != sorted(requiredBaseColNames):
        raise TypeError(f"The base table has these columns '{baseColNames}', but these columns are required '{requiredBaseColNames}'")

    # validate the updates DataFrame
    updatesColNames = updates_df.columns
    requiredUpdatesColNames = [primaryKey] + attrColNames + [effectiveTimeColName]
    if sorted(updatesColNames) != sorted(requiredUpdatesColNames):
        raise TypeError(f"The updates DataFrame has these columns '{updatesColNames}', but these columns are required '{requiredUpdatesColNames}'")

    # perform the upsert
    # a row needs a new version when any tracked attribute differs between the update and the current base row
    updatesAttrs = " OR ".join(f"updates.{attr} <> base.{attr}" for attr in attrColNames)
    stagedUpdatesAttrs = " OR ".join(f"staged_updates.{attr} <> base.{attr}" for attr in attrColNames)

    # stage each changed update twice: once with a NULL mergeKey so it is inserted as the new
    # current row, and once keyed by the primary key so the existing current row gets closed out
    stagedPart1 = (
        updates_df.alias("updates")
        .join(baseTable.toDF().alias("base"), primaryKey)
        .where(f"base.{isCurrentColName} = true AND ({updatesAttrs})")
        .selectExpr("NULL as mergeKey", "updates.*")
    )
    stagedPart2 = updates_df.selectExpr(f"{primaryKey} as mergeKey", "*")
    stagedUpdates = stagedPart1.union(stagedPart2)

    # values inserted for brand-new keys and for the new version of changed keys
    attrValues = {attr: f"staged_updates.{attr}" for attr in attrColNames}
    insertValues = {
        **attrValues,
        primaryKey: f"staged_updates.{primaryKey}",
        isCurrentColName: "true",
        effectiveTimeColName: f"staged_updates.{effectiveTimeColName}",
        endTimeColName: "null",
    }

    res = (
        baseTable.alias("base")
        .merge(
            source=stagedUpdates.alias("staged_updates"),
            condition=expr(f"base.{primaryKey} = mergeKey AND base.{isCurrentColName} = true AND ({stagedUpdatesAttrs})"),
        )
        .whenMatchedUpdate(set={isCurrentColName: "false", endTimeColName: f"staged_updates.{effectiveTimeColName}"})
        .whenNotMatchedInsert(values=insertValues)
        .execute()
    )
    return res
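
For reference, a minimal usage sketch (not part of this commit): it assumes an already-configured SparkSession bound to spark, an existing SCD2 Delta table at a hypothetical tmp/customers_delta path, and the default is_current / effective_time / end_time column names that type_2_scd_upsert passes through.

import datetime

import mack

# hypothetical updates DataFrame: the primary key, the tracked attribute, and the effective time
updates_df = spark.createDataFrame(
    [(1, "A2", datetime.datetime(2023, 1, 1))],
    ["customer_id", "attr", "effective_time"],
)

# closes out the old current row for key 1 and inserts the new current version
mack.type_2_scd_upsert("tmp/customers_delta", updates_df, "customer_id", ["attr"])
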
176 changes: 176 additions & 0 deletions poetry.lock

Some generated files are not rendered by default.

20 changes: 20 additions & 0 deletions pyproject.toml
@@ -0,0 +1,20 @@
[tool.poetry]
name = "mack"
version = "0.1.0"
description = ""
authors = ["Matthew Powers <matthewkevinpowers@gmail.com>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.9"

[tool.poetry.dev-dependencies]
pyspark = "3.3.1"
delta-spark = "2.1.1"
pytest = "3.2.2"
chispa = "0.9.2"
pytest-describe = "^1.0.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
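
With these dependencies declared, the test suite can most likely be run locally through Poetry (assuming Poetry itself and a JVM supported by Spark 3.3 are installed):

# install the package together with the dev dependencies listed above
poetry install

# run the test suite
poetry run pytest tests
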
Empty file added tests/__init__.py
74 changes: 74 additions & 0 deletions tests/test_public_interface.py
@@ -0,0 +1,74 @@
import pytest
import chispa
import pyspark
from delta import configure_spark_with_delta_pip
import datetime
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, BooleanType, DateType
import mack

builder = (
    pyspark.sql.SparkSession.builder.appName("MyApp")
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
    .config(
        "spark.sql.catalog.spark_catalog",
        "org.apache.spark.sql.delta.catalog.DeltaCatalog",
    )
)

spark = configure_spark_with_delta_pip(builder).getOrCreate()

def test_type_2_scd_generic_upsert():
    path = "tmp/delta-upsert-date"

    # create the base Delta table
    data = [
        (1, "A", True, datetime.datetime(2019, 1, 1), None),
        (2, "B", True, datetime.datetime(2019, 1, 1), None),
        (4, "D", True, datetime.datetime(2019, 1, 1), None),
    ]

    schema = StructType([
        StructField("pkey", IntegerType(), True),
        StructField("attr", StringType(), True),
        StructField("cur", BooleanType(), True),
        StructField("effective_date", DateType(), True),
        StructField("end_date", DateType(), True),
    ])

    df = spark.createDataFrame(data=data, schema=schema)
    df.write.format("delta").save(path)

    # create updates DF
    updatesDF = spark.createDataFrame([
        (3, "C", datetime.datetime(2020, 9, 15)),  # new value
        (2, "Z", datetime.datetime(2020, 1, 1)),  # value to upsert
    ]).toDF("pkey", "attr", "effective_date")

    # perform upsert
    mack.type_2_scd_generic_upsert(path, updatesDF, "pkey", ["attr"], "cur", "effective_date", "end_date")

    actual_df = spark.read.format("delta").load(path)

    expected_df = spark.createDataFrame([
        (2, "B", False, datetime.datetime(2019, 1, 1), datetime.datetime(2020, 1, 1)),
        (3, "C", True, datetime.datetime(2020, 9, 15), None),
        (2, "Z", True, datetime.datetime(2020, 1, 1), None),
        (4, "D", True, datetime.datetime(2019, 1, 1), None),
        (1, "A", True, datetime.datetime(2019, 1, 1), None),
    ], schema)

    chispa.assert_df_equality(actual_df, expected_df, ignore_row_order=True)

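
One caveat: df.write.format("delta").save(path) uses the default errorifexists save mode, so re-running this test fails once tmp/delta-upsert-date already exists. A small autouse fixture along these lines (an illustrative sketch, not part of this commit) would keep the test re-runnable:

import shutil

import pytest


@pytest.fixture(autouse=True)
def clean_tmp_delta_path():
    # remove any Delta table left over from a previous run before the test writes to the path
    shutil.rmtree("tmp/delta-upsert-date", ignore_errors=True)
    yield
    # clean up after the test as well
    shutil.rmtree("tmp/delta-upsert-date", ignore_errors=True)
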
