Skip to content

Commit

Permalink
fix Union and MergeInto bug
Browse files Browse the repository at this point in the history
  • Loading branch information
iodone committed Mar 17, 2023
1 parent 58a4f58 commit 2adddcc
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ trait LineageParser {
val nextColumnsLlineage = ListMap(allAssignments.map { assignment =>
(
assignment.key.asInstanceOf[Attribute],
AttributeSet(assignment.value.asInstanceOf[Attribute]))
assignment.value.references)
}: _*)
val targetTable = getPlanField[LogicalPlan]("targetTable", plan)
val sourceTable = getPlanField[LogicalPlan]("sourceTable", plan)
Expand Down Expand Up @@ -376,14 +376,22 @@ trait LineageParser {
}

case p: Union =>
// merge all children in to one derivedColumns
val childrenUnion =
p.children.map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]())).map(
_.values).reduce {
(left, right) =>
left.zip(right).map(attr => attr._1 ++ attr._2)
val childrenColumnsLineage =
// support for the multi-insert statement
if (p.output.isEmpty) {
p.children
.map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]()))
.reduce(mergeColumnsLineage)
} else {
// merge all children in to one derivedColumns
val childrenUnion =
p.children.map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]())).map(
_.values).reduce {
(left, right) =>
left.zip(right).map(attr => attr._1 ++ attr._2)
}
ListMap(p.output.zip(childrenUnion): _*)
}
val childrenColumnsLineage = ListMap(p.output.zip(childrenUnion): _*)
joinColumnsLineage(parentColumnsLineage, childrenColumnsLineage)

case p: LogicalRelation if p.catalogTable.nonEmpty =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
"WHEN MATCHED THEN " +
" UPDATE SET target.name = source.name, target.price = source.price " +
"WHEN NOT MATCHED THEN " +
" INSERT (id, name, price) VALUES (source.id, source.name, source.price)")
" INSERT (id, name, price) VALUES (cast(source.id as int), source.name, source.price)")
assert(ret0 == Lineage(
List("v2_catalog.db.source_t"),
List("v2_catalog.db.target_t"),
Expand Down Expand Up @@ -1287,6 +1287,25 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite
}
}

test("test the statement with FROM xxx INSERT xxx") {
withTable("t1", "t2", "t3") { _ =>
spark.sql("CREATE TABLE t1 (a string, b string) USING hive")
spark.sql("CREATE TABLE t2 (a string, b string) USING hive")
spark.sql("CREATE TABLE t3 (a string, b string) USING hive")
val ret0 = exectractLineage("from (select a,b from t1)" +
" insert overwrite table t2 select a,b where a=1" +
" insert overwrite table t3 select a,b where b=1")
assert(ret0 == Lineage(
List("default.t1"),
List("default.t2", "default.t3"),
List(
("default.t2.a", Set("default.t1.a")),
("default.t2.b", Set("default.t1.b")),
("default.t3.a", Set("default.t1.a")),
("default.t3.b", Set("default.t1.b")))))
}
}

private def exectractLineageWithoutExecuting(sql: String): Lineage = {
val parsed = spark.sessionState.sqlParser.parsePlan(sql)
val analyzed = spark.sessionState.analyzer.execute(parsed)
Expand Down

0 comments on commit 2adddcc

Please sign in to comment.