diff --git a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala index 1224da02339..f78910aedec 100644 --- a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala +++ b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala @@ -311,7 +311,7 @@ trait LineageParser { val nextColumnsLlineage = ListMap(allAssignments.map { assignment => ( assignment.key.asInstanceOf[Attribute], - AttributeSet(assignment.value.asInstanceOf[Attribute])) + assignment.value.references) }: _*) val targetTable = getPlanField[LogicalPlan]("targetTable", plan) val sourceTable = getPlanField[LogicalPlan]("sourceTable", plan) @@ -370,14 +370,22 @@ trait LineageParser { } case p: Union => - // merge all children in to one derivedColumns - val childrenUnion = - p.children.map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]())).map( - _.values).reduce { - (left, right) => - left.zip(right).map(attr => attr._1 ++ attr._2) + val childrenColumnsLineage = + // support for the multi-insert statement + if (p.output.isEmpty) { + p.children + .map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]())) + .reduce(mergeColumnsLineage) + } else { + // merge all children in to one derivedColumns + val childrenUnion = + p.children.map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]())).map( + _.values).reduce { + (left, right) => + left.zip(right).map(attr => attr._1 ++ attr._2) + } + ListMap(p.output.zip(childrenUnion): _*) } - val childrenColumnsLineage = ListMap(p.output.zip(childrenUnion): _*) joinColumnsLineage(parentColumnsLineage, childrenColumnsLineage) case p: LogicalRelation if p.catalogTable.nonEmpty => diff --git a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala index 6180980c835..e94c88f6bc3 100644 --- a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala +++ b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala @@ -171,7 +171,7 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite "WHEN MATCHED THEN " + " UPDATE SET target.name = source.name, target.price = source.price " + "WHEN NOT MATCHED THEN " + - " INSERT (id, name, price) VALUES (source.id, source.name, source.price)") + " INSERT (id, name, price) VALUES (cast(source.id as int), source.name, source.price)") assert(ret0 == Lineage( List("v2_catalog.db.source_t"), List("v2_catalog.db.target_t"), @@ -1249,6 +1249,25 @@ class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite } } + test("test the statement with FROM xxx INSERT xxx") { + withTable("t1", "t2", "t3") { _ => + spark.sql("CREATE TABLE t1 (a string, b string) USING hive") + spark.sql("CREATE TABLE t2 (a string, b string) USING hive") + spark.sql("CREATE TABLE t3 (a string, b string) USING hive") + val ret0 = exectractLineage("from (select a,b from t1)" + + " insert overwrite table t2 select a,b where a=1" + + " insert overwrite table t3 select a,b where b=1") + assert(ret0 == Lineage( + List("default.t1"), + List("default.t2", "default.t3"), + List( + ("default.t2.a", Set("default.t1.a")), + ("default.t2.b", Set("default.t1.b")), + ("default.t3.a", Set("default.t1.a")), + ("default.t3.b", Set("default.t1.b"))))) + } + } + private def exectractLineageWithoutExecuting(sql: String): Lineage = { val parsed = spark.sessionState.sqlParser.parsePlan(sql) val analyzed = spark.sessionState.analyzer.execute(parsed)