Commit 4d048b3

Author: xy_xin
Add support for column aliases
1 parent 5b738cf commit 4d048b3

5 files changed: +188 −50 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 18 additions & 7 deletions
@@ -354,15 +354,25 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
 
   override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
     val tableId = visitMultipartIdentifier(ctx.multipartIdentifier)
-    val tableAlias = if (ctx.tableAlias() != null) {
+    val (tableAlias, colsAlias) = if (ctx.tableAlias() != null) {
       val ident = ctx.tableAlias().strictIdentifier()
-      if (ident != null) { Some(ident.getText) } else { None }
+      val colList = ctx.tableAlias().identifierList()
+      if (ident != null) {
+        val cols = if (colList != null) {
+          Some(visitIdentifierList(colList))
+        } else {
+          None
+        }
+        (Some(ident.getText), cols)
+      } else {
+        (None, None)
+      }
     } else {
-      None
+      (None, None)
     }
-    val sets = ctx.setClause().assign().asScala.map {
+    val (attrs, values) = ctx.setClause().assign().asScala.map {
       kv => visitMultipartIdentifier(kv.key) -> expression(kv.value)
-    }.toMap
+    }.unzip
     val predicate = if (ctx.whereClause() != null) {
       Some(expression(ctx.whereClause().booleanExpression()))
     } else {
@@ -372,8 +382,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     UpdateTableStatement(
       tableId,
       tableAlias,
-      sets.keys.toSeq,
-      sets.values.toSeq,
+      colsAlias,
+      attrs,
+      values,
       predicate)
   }

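The parser change above makes an optional column-alias list after the table alias legal in UPDATE. A minimal sketch of the accepted shape, mirroring the new test added below (table name and aliases are illustrative):

sql(s"UPDATE $t AS tbl(a, b, c, d) SET b = 'Robert', c = 32 WHERE d = 2")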
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/UpdateTableStatement.scala

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
 case class UpdateTableStatement(
     tableName: Seq[String],
     tableAlias: Option[String],
+    colsAliases: Option[Seq[String]],
     attrs: Seq[Seq[String]],
     values: Seq[Expression],
     condition: Option[Expression]) extends ParsedStatement

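As a rough sketch of the new field in use (the instance below is hypothetical, built by hand rather than by the parser), UPDATE db.tbl AS t(a, b) SET a = 1 WHERE b = 2 would parse to something like:

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal}

UpdateTableStatement(
  tableName = Seq("db", "tbl"),
  tableAlias = Some("t"),
  colsAliases = Some(Seq("a", "b")),  // the new column-alias list
  attrs = Seq(Seq("a")),
  values = Seq(Literal(1)),
  condition = Some(EqualTo(UnresolvedAttribute("b"), Literal(2))))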
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala

Lines changed: 7 additions & 4 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.sql.{AnalysisException, SaveMode}
 import org.apache.spark.sql.catalog.v2.{CatalogManager, CatalogPlugin, Identifier, LookupCatalog, TableCatalog}
 import org.apache.spark.sql.catalog.v2.expressions.Transform
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, UnresolvedCatalogRelation}
 import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DeleteFromTable, DropTable, Filter, LogicalPlan, ReplaceTable, ReplaceTableAsSelect, ShowTables, SubqueryAlias, UpdateTable}
 import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DeleteFromStatement, DescribeColumnStatement, DescribeTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement, ShowTablesStatement, UpdateTableStatement}
@@ -178,16 +178,19 @@ case class DataSourceResolution(
       val aliased = delete.tableAlias.map(SubqueryAlias(_, relation)).getOrElse(relation)
       DeleteFromTable(aliased, delete.condition)
 
-    case UpdateTableStatement(AsTableIdentifier(table), tableAlias, attrs, values, condition) =>
+    case UpdateTableStatement(AsTableIdentifier(table),
+        tableAlias, colsAlias, attrs, values, condition) =>
       throw new AnalysisException(
         s"Update table is not supported using the legacy / v1 Spark external catalog" +
           s" API. Identifier: $table.")
 
     case update: UpdateTableStatement =>
       val relation = UnresolvedRelation(update.tableName)
-      val aliased = update.tableAlias.map(SubqueryAlias(_, relation)).getOrElse(relation)
+      val aliasedTbl = update.tableAlias.map(SubqueryAlias(_, relation)).getOrElse(relation)
+      val aliasedTblWithCols =
+        update.colsAliases.map(UnresolvedSubqueryColumnAliases(_, aliasedTbl)).getOrElse(aliasedTbl)
       val attrs = update.attrs.map(UnresolvedAttribute(_))
-      UpdateTable(aliased, attrs, update.values, update.condition)
+      UpdateTable(aliasedTblWithCols, attrs, update.values, update.condition)
 
     case ShowTablesStatement(None, pattern) =>
       defaultCatalog match {

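Conceptually, the resolution step now wraps the relation twice when both kinds of alias are present. A sketch (plan shape inferred from the code above, names taken from the new test) of the unresolved child plan produced for UPDATE testcat.ns1.ns2.tbl AS tbl(a, b, c, d) ...:

UnresolvedSubqueryColumnAliases(Seq("a", "b", "c", "d"),
  SubqueryAlias("tbl",
    UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl"))))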
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala

Lines changed: 46 additions & 8 deletions
@@ -24,9 +24,9 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.{AnalysisException, Strategy}
 import org.apache.spark.sql.catalog.v2.StagingTableCatalog
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition, ReplaceTable, ReplaceTableAsSelect, ShowTables, UpdateTable}
+import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, AppendData, CreateTableAsSelect, CreateV2Table, DeleteFromTable, DescribeTable, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, Repartition, ReplaceTable, ReplaceTableAsSelect, ShowTables, UpdateTable}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
@@ -246,27 +246,55 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
       }.toArray
       DeleteFromTableExec(r.table.asDeletable, filters) :: Nil
 
-    case UpdateTable(r: DataSourceV2Relation, attrs, values, condition) =>
+    case UpdateTable(maybeRelation, attrs, values, condition) =>
       val nested = attrs.asInstanceOf[Seq[Any]].filterNot(_.isInstanceOf[AttributeReference])
       if (nested.nonEmpty) {
-        throw new RuntimeException(s"Update only support non-nested fields. Nested: $nested")
+        throw new AnalysisException(s"Update only support non-nested fields. Nested: $nested")
+      }
+
+      val (relation, attrsNames, newValues, newCond) = maybeRelation match {
+        case d: DataSourceV2Relation =>
+          (d, attrs.map(_.name), values, condition)
+        case Project(aliasList: Seq[Alias], r: DataSourceV2Relation) =>
+          // given the aliased columns, resolve the original columns.
+          val lookup = aliasList.map {
+            alias => alias.name -> alias.child.asInstanceOf[AttributeReference]
+          }.toMap
+          def replaceAttr(input: Expression): Expression = {
+            input.transformDown {
+              case a: AttributeReference => lookup.getOrElse(a.name, a)
+              case other => other
+            }
+          }
+
+          val newValues = values.map(replaceAttr)
+          val newCond = condition.map(replaceAttr)
+          val attrNames = attrs.map(replaceAttr).asInstanceOf[Seq[AttributeReference]].map(_.name)
+
+          val relationAttrNames = r.output.map(_.name)
+          if (!attrNames.forall(relationAttrNames.contains)) {
+            throw new AnalysisException(s"Exec update failed:" +
+              s" cannot resolve fields ${attrNames.diff(relationAttrNames)}")
+          }
+          (r, attrNames, newValues, newCond)
+        case _ =>
+          throw new AnalysisException(s"Exec update failed: cannot resolve $maybeRelation")
       }
 
-      val attrsNames = attrs.map(_.name)
       // fail if any updated value cannot be converted.
-      val updatedValues = values.map {
+      val updatedValues = newValues.map {
         v => DataSourceStrategy.translateExpression(v).getOrElse(
           throw new AnalysisException(s"Exec update failed:" +
             s" cannot translate update set to source expression: $v"))
       }
       // fail if any filter cannot be converted. correctness depends on removing all matching data.
-      val filters = condition.map(
+      val filters = newCond.map(
         splitConjunctivePredicates(_).map {
           f => DataSourceStrategy.translateFilter(f).getOrElse(
             throw new AnalysisException(s"Exec update failed:" +
               s" cannot translate expression to source filter: $f"))
         }.toArray).getOrElse(Array.empty[Filter])
-      UpdateTableExec(r.table.asUpdatable, attrsNames, updatedValues, filters)::Nil
+      UpdateTableExec(relation.table.asUpdatable, attrsNames, updatedValues, filters)::Nil
 
     case WriteToContinuousDataSource(writer, query) =>
       WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil
@@ -298,3 +326,13 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
     case _ => Nil
   }
 }
+
+object eliminateProject {
+  def unapply(plan: LogicalPlan): Option[DataSourceV2Relation] = {
+    plan match {
+      case p: DataSourceV2Relation => Some(p)
+      case Project(_, p: DataSourceV2Relation) => Some(p)
+      case _ => None
+    }
+  }
+}

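The heart of the new Project branch is the reverse alias lookup: each Alias introduced by the column-alias list maps an aliased name back to the column it renames, so SET values and WHERE predicates can be rewritten against the scan's real output. A minimal self-contained sketch of that idea, simplified from the code above (the helper name is ours, not part of the patch):

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression}

// Rewrite references to aliased names back to the underlying columns,
// so they can later be translated into source expressions and filters.
def resolveAliases(aliasList: Seq[Alias], expr: Expression): Expression = {
  val lookup = aliasList.map(a => a.name -> a.child).toMap
  expr.transformDown {
    case attr: AttributeReference => lookup.getOrElse(attr.name, attr)
  }
}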
sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala

Lines changed: 116 additions & 31 deletions
@@ -1770,12 +1770,22 @@ class DataSourceV2SQLSuite
   test("Update: basic") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
-        " USING foo" +
-        " PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
-        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(
+        s"""
+           | CREATE TABLE $t (id bigint, name string, age int, p int)
+           | USING foo
+           | PARTITIONED BY (id, p)
+         """.stripMargin)
+      sql(
+        s"""
+           | INSERT INTO $t
+           | VALUES (1L, 'Herry', 26, 1),
+           |        (2L, 'Jack', 31, 2),
+           |        (3L, 'Lisa', 28, 3),
+           |        (4L, 'Frank', 33, 3)
+         """.stripMargin)
       sql(s"UPDATE $t SET name='Robert', age=32")
+
       checkAnswer(spark.table(t),
         Seq(Row(1, "Robert", 32, 1),
           Row(2, "Robert", 32, 2),
@@ -1787,12 +1797,22 @@ class DataSourceV2SQLSuite
   test("Update: update with where clause") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
-        " USING foo" +
-        " PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
-        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(
+        s"""
+           | CREATE TABLE $t (id bigint, name string, age int, p int)
+           | USING foo
+           | PARTITIONED BY (id, p)
+         """.stripMargin)
+      sql(
+        s"""
+           | INSERT INTO $t
+           | VALUES (1L, 'Herry', 26, 1),
+           |        (2L, 'Jack', 31, 2),
+           |        (3L, 'Lisa', 28, 3),
+           |        (4L, 'Frank', 33, 3)
+         """.stripMargin)
       sql(s"UPDATE $t SET name='Robert', age=32 where p=2")
+
       checkAnswer(spark.table(t),
         Seq(Row(1, "Herry", 26, 1),
           Row(2, "Robert", 32, 2),
@@ -1804,12 +1824,22 @@ class DataSourceV2SQLSuite
   test("Update: update the partition key") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
-        " USING foo" +
-        " PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
-        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(
+        s"""
+           | CREATE TABLE $t (id bigint, name string, age int, p int)
+           | USING foo
+           | PARTITIONED BY (id, p)
+         """.stripMargin)
+      sql(
+        s"""
+           | INSERT INTO $t
+           | VALUES (1L, 'Herry', 26, 1),
+           |        (2L, 'Jack', 31, 2),
+           |        (3L, 'Lisa', 28, 3),
+           |        (4L, 'Frank', 33, 3)
+         """.stripMargin)
       sql(s"UPDATE $t SET p=4 where id=4")
+
       checkAnswer(spark.table(t),
         Seq(Row(1, "Herry", 26, 1),
           Row(2, "Jack", 31, 2),
@@ -1821,12 +1851,49 @@ class DataSourceV2SQLSuite
   test("Update: update with aliased target table") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
-        " USING foo" +
-        " PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
-        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(
+        s"""
+           | CREATE TABLE $t (id bigint, name string, age int, p int)
+           | USING foo
+           | PARTITIONED BY (id, p)
+         """.stripMargin)
+      sql(
+        s"""
+           | INSERT INTO $t
+           | VALUES (1L, 'Herry', 26, 1),
+           |        (2L, 'Jack', 31, 2),
+           |        (3L, 'Lisa', 28, 3),
+           |        (4L, 'Frank', 33, 3)
+         """.stripMargin)
       sql(s"UPDATE $t AS tbl SET tbl.name='Robert', tbl.age=32 where p=2")
+
+      checkAnswer(spark.table(t),
+        Seq(Row(1, "Herry", 26, 1),
+          Row(2, "Robert", 32, 2),
+          Row(3, "Lisa", 28, 3),
+          Row(4, "Frank", 33, 3)))
+    }
+  }
+
+  test("Update: update with aliased target table columns") {
+    val t = "testcat.ns1.ns2.tbl"
+    withTable(t) {
+      sql(
+        s"""
+           | CREATE TABLE $t (id bigint, name string, age int, p int)
+           | USING foo
+           | PARTITIONED BY (id, p)
+         """.stripMargin)
+      sql(
+        s"""
+           | INSERT INTO $t
+           | VALUES (1L, 'Herry', 26, 1),
+           |        (2L, 'Jack', 31, 2),
+           |        (3L, 'Lisa', 28, 3),
+           |        (4L, 'Frank', 33, 3)
+         """.stripMargin)
+      sql(s"UPDATE $t AS tbl(a, b, c, d) SET b='Robert', c=32 where d=2")
+
       checkAnswer(spark.table(t),
         Seq(Row(1, "Herry", 26, 1),
           Row(2, "Robert", 32, 2),
@@ -1842,7 +1909,7 @@ class DataSourceV2SQLSuite
       sql(s"INSERT INTO $t SELECT 1, named_struct('x', 1.0D, 'y', 1.0D)")
       checkAnswer(spark.table(t), Seq(Row(1, Row(1.0, 1.0))))
 
-      val exc = intercept[RuntimeException] {
+      val exc = intercept[AnalysisException] {
         sql(s"UPDATE $t tbl SET tbl.point.x='2.0D', tbl.point.y='3.0D'")
       }
 
@@ -1854,11 +1921,20 @@ class DataSourceV2SQLSuite
   test("Update: fail if the value expression in set clause cannot be converted") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
-        " USING foo" +
-        " PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
-        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(
+        s"""
+           | CREATE TABLE $t (id bigint, name string, age int, p int)
+           | USING foo
+           | PARTITIONED BY (id, p)
+         """.stripMargin)
+      sql(
+        s"""
+           | INSERT INTO $t
+           | VALUES (1L, 'Herry', 26, 1),
+           |        (2L, 'Jack', 31, 2),
+           |        (3L, 'Lisa', 28, 3),
+           |        (4L, 'Frank', 33, 3)
+         """.stripMargin)
       val exc = intercept[AnalysisException] {
         sql(s"UPDATE $t tbl SET tbl.p=tbl.p + 1 WHERE id = 3")
       }
@@ -1871,11 +1947,20 @@ class DataSourceV2SQLSuite
   test("Update: fail if has subquery") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
-        " USING foo" +
-        " PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
-        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(
+        s"""
+           | CREATE TABLE $t (id bigint, name string, age int, p int)
+           | USING foo
+           | PARTITIONED BY (id, p)
+         """.stripMargin)
+      sql(
+        s"""
+           | INSERT INTO $t
+           | VALUES (1L, 'Herry', 26, 1),
+           |        (2L, 'Jack', 31, 2),
+           |        (3L, 'Lisa', 28, 3),
+           |        (4L, 'Frank', 33, 3)
+         """.stripMargin)
       val exc = intercept[AnalysisException] {
        sql(s"UPDATE $t SET name='Robert' WHERE id IN (SELECT id FROM $t)")
      }
