Commit a76062c

Updated according to review comments

Author: xy_xin
1 parent a58a87b · commit a76062c

7 files changed: +127 -30 lines changed

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
Lines changed: 2 additions & 2 deletions

@@ -215,7 +215,7 @@ statement
     | SET .*? #setConfiguration
     | RESET #resetConfiguration
     | DELETE FROM multipartIdentifier tableAlias whereClause #deleteFromTable
-    | UPDATE multipartIdentifier tableAlias setClause whereClause #updateTable
+    | UPDATE multipartIdentifier tableAlias setClause (whereClause)? #updateTable
     | unsupportedHiveNativeCommands .*? #failNativeCommand
     ;

@@ -480,7 +480,7 @@ setClause
     ;

 assign
-    : key=multipartIdentifier EQ value=valueExpression
+    : key=multipartIdentifier EQ value=expression
     ;

 whereClause
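With these two grammar changes, UPDATE's WHERE clause becomes optional and the right-hand side of an assignment is a full expression rather than only a valueExpression. A minimal sketch of what the revised rule now accepts, reusing the table and column names from the test suite below:

    // Both statements now parse; without a WHERE clause the update
    // applies to every row of the table.
    spark.sql("UPDATE testcat.ns1.ns2.tbl SET name='Robert', age=32 WHERE p = 2")
    spark.sql("UPDATE testcat.ns1.ns2.tbl SET name='Robert', age=32")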

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
Lines changed: 6 additions & 1 deletion

@@ -363,13 +363,18 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     val sets = ctx.setClause().assign().asScala.map {
       kv => visitMultipartIdentifier(kv.key) -> expression(kv.value)
     }.toMap
+    val predicate = if (ctx.whereClause() != null) {
+      Some(expression(ctx.whereClause().booleanExpression()))
+    } else {
+      None
+    }

     UpdateTableStatement(
       tableId,
       tableAlias,
       sets.keys.toSeq,
       sets.values.toSeq,
-      expression(ctx.whereClause().booleanExpression()))
+      predicate)
   }

   /**
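The builder now checks for a missing whereClause context instead of unconditionally dereferencing it, mapping the absent clause to None. The same null-to-Option conversion can be written more compactly; a sketch, assuming the same ANTLR context and expression helper as above:

    // Option(...) turns a null ANTLR context into None, so an UPDATE
    // without WHERE yields predicate = None rather than an NPE.
    val predicate: Option[Expression] =
      Option(ctx.whereClause()).map(w => expression(w.booleanExpression()))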

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
Lines changed: 1 addition & 1 deletion

@@ -578,7 +578,7 @@ case class UpdateTable(
     child: LogicalPlan,
     attrs: Seq[Attribute],
     values: Seq[Expression],
-    condition: Expression) extends Command {
+    condition: Option[Expression]) extends Command {

   override def children: Seq[LogicalPlan] = child :: Nil
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/UpdateTableStatement.scala
Lines changed: 1 addition & 1 deletion

@@ -24,4 +24,4 @@ case class UpdateTableStatement(
     tableAlias: Option[String],
     attrs: Seq[Seq[String]],
     values: Seq[Expression],
-    condition: Expression) extends ParsedStatement
+    condition: Option[Expression]) extends ParsedStatement
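Both the parsed statement and the UpdateTable logical plan now carry the predicate as Option[Expression]. A hedged sketch of how the two UPDATE shapes map onto the statement node, following the argument order of the AstBuilder call above; the attribute and literal values are illustrative, not from the commit:

    // Unconditional UPDATE: condition = None.
    val unconditional = UpdateTableStatement(
      Seq("testcat", "ns1", "ns2", "tbl"),   // multipart table name
      None,                                  // no table alias
      Seq(Seq("name")),                      // attrs to set
      Seq(Literal("Robert")),                // new values
      None)                                  // no WHERE clause
    // Filtered UPDATE: condition = Some(predicate).
    val filtered = unconditional.copy(
      condition = Some(EqualTo(UnresolvedAttribute("p"), Literal(2))))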

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
Lines changed: 12 additions & 5 deletions

@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
 import org.apache.spark.sql.sources
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.sources.v2.TableCapability
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}

@@ -246,6 +247,11 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
       DeleteFromTableExec(r.table.asDeletable, filters) :: Nil

     case UpdateTable(r: DataSourceV2Relation, attrs, values, condition) =>
+      val nested = attrs.asInstanceOf[Seq[Any]].filterNot(_.isInstanceOf[AttributeReference])
+      if (nested.nonEmpty) {
+        throw new RuntimeException(s"Update only support non-nested fields. Nested: $nested")
+      }
       val attrsNames = attrs.map(_.name)
       // fail if any updated value cannot be converted.
       val updatedValues = values.map {

@@ -254,11 +260,12 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
           s" cannot translate update set to source expression: $v"))
       }
       // fail if any filter cannot be converted. correctness depends on removing all matching data.
-      val filters = splitConjunctivePredicates(condition).map {
-        f => DataSourceStrategy.translateFilter(f).getOrElse(
-          throw new AnalysisException(s"Exec update failed:" +
-            s" cannot translate expression to source filter: $f"))
-      }.toArray
+      val filters = condition.map(
+        splitConjunctivePredicates(_).map {
+          f => DataSourceStrategy.translateFilter(f).getOrElse(
+            throw new AnalysisException(s"Exec update failed:" +
+              s" cannot translate expression to source filter: $f"))
+        }.toArray).getOrElse(Array.empty[Filter])
       UpdateTableExec(r.table.asUpdatable, attrsNames, updatedValues, filters)::Nil

     case WriteToContinuousDataSource(writer, query) =>
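An absent condition now becomes an empty pushed-down filter array, which for an updatable source means every row matches; a present condition is still split into conjuncts, each of which must translate to a source filter or planning fails. An equivalent unrolled sketch of the new expression:

    // Sketch: pattern-match form of condition.map(...).getOrElse(...) above.
    val filters: Array[Filter] = condition match {
      case None => Array.empty[Filter]   // no WHERE clause: match all rows
      case Some(cond) => splitConjunctivePredicates(cond).map { f =>
        DataSourceStrategy.translateFilter(f).getOrElse(
          throw new AnalysisException(
            s"Exec update failed: cannot translate expression to source filter: $f"))
      }.toArray
    }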

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala
Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ object V2WriteSupportCheck extends (LogicalPlan => Unit) {
     }

     case UpdateTable(_, _, _, condition) =>
-      if (SubqueryExpression.hasSubquery(condition)) {
+      if (condition.exists(SubqueryExpression.hasSubquery)) {
        failAnalysis(s"Update by condition with subquery is not supported: $condition")
      }
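Option.exists returns false for None, so an unconditional UPDATE now passes this check, while any predicate containing a subquery still fails analysis. A tiny sketch of the semantics being relied on:

    // None.exists(p) == false; Some(x).exists(p) == p(x)
    val noWhere: Option[Expression] = None
    assert(!noWhere.exists(SubqueryExpression.hasSubquery)) // unconditional UPDATE is allowed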

sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala
Lines changed: 104 additions & 19 deletions

@@ -1770,49 +1770,134 @@ class DataSourceV2SQLSuite
   test("Update: basic") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"UPDATE $t SET data='d' WHERE id = 2")
-      checkAnswer(spark.table(t), Seq(
-        Row(2, "d", 2), Row(2, "d", 3), Row(3, "c", 3)))
+      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
+        " USING foo" +
+        " PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
+        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(s"UPDATE $t SET name='Robert', age=32")
+      checkAnswer(spark.table(t),
+        Seq(Row(1, "Robert", 32, 1),
+          Row(2, "Robert", 32, 2),
+          Row(3, "Robert", 32, 3),
+          Row(4, "Robert", 32, 3)))
     }
   }

-  test("Update: alias") {
+  test("Update: update with where clause") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"UPDATE $t tbl SET tbl.data='d' WHERE id = 2")
-      checkAnswer(spark.table(t), Seq(
-        Row(2, "d", 2), Row(2, "d", 3), Row(3, "c", 3)))
+      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
+        " USING foo" +
+        " PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
+        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(s"UPDATE $t SET name='Robert', age=32 where p=2")
+      checkAnswer(spark.table(t),
+        Seq(Row(1, "Herry", 26, 1),
+          Row(2, "Robert", 32, 2),
+          Row(3, "Lisa", 28, 3),
+          Row(4, "Frank", 33, 3)))
+    }
+  }
+
+  test("Update: update the partition key") {
+    val t = "testcat.ns1.ns2.tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
+        " USING foo" +
+        " PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
+        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(s"UPDATE $t SET p=4 where id=4")
+      checkAnswer(spark.table(t),
+        Seq(Row(1, "Herry", 26, 1),
+          Row(2, "Jack", 31, 2),
+          Row(3, "Lisa", 28, 3),
+          Row(4, "Frank", 33, 4)))
+    }
+  }
+
+  test("Update: update with aliased target table - 1") {
+    val t = "testcat.ns1.ns2.tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
+        " USING foo" +
+        " PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
+        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(s"UPDATE $t tbl SET tbl.name='Robert', tbl.age=32 where p=2")
+      checkAnswer(spark.table(t),
+        Seq(Row(1, "Herry", 26, 1),
+          Row(2, "Robert", 32, 2),
+          Row(3, "Lisa", 28, 3),
+          Row(4, "Frank", 33, 3)))
+    }
+  }
+
+  test("Update: update with aliased target table - 2") {
+    val t = "testcat.ns1.ns2.tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
+        " USING foo" +
+        " PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
+        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
+      sql(s"UPDATE $t AS tbl SET tbl.name='Robert', tbl.age=32 where p=2")
+      checkAnswer(spark.table(t),
+        Seq(Row(1, "Herry", 26, 1),
+          Row(2, "Robert", 32, 2),
+          Row(3, "Lisa", 28, 3),
+          Row(4, "Frank", 33, 3)))
+    }
+  }
+
+  test("Update: update nested field") {
+    val t = "testcat.ns1.ns2.tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id int, point struct<x: double, y: double>) USING foo")
+      sql(s"INSERT INTO $t SELECT 1, named_struct('x', 1.0D, 'y', 1.0D)")
+      checkAnswer(spark.table(t), Seq(Row(1, Row(1.0, 1.0))))
+
+      val exc = intercept[RuntimeException] {
+        sql(s"UPDATE $t tbl SET tbl.point.x='2.0D', tbl.point.y='3.0D'")
+      }
+
+      checkAnswer(spark.table(t), Seq(Row(1, Row(1.0, 1.0))))
+      assert(exc.getMessage.contains("Update only support non-nested fields."))
     }
   }

   test("Update: fail if the value expression in set clause cannot be converted") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
+        " USING foo" +
+        " PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
+        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
       val exc = intercept[AnalysisException] {
-        sql(s"UPDATE $t tbl SET tbl.id=tbl.id + 1 WHERE id = 3")
+        sql(s"UPDATE $t tbl SET tbl.p=tbl.p + 1 WHERE id = 3")
       }

-      assert(spark.table(t).filter("id=3").select("data").head().getString(0) == "c")
+      assert(spark.table(t).filter("id=3").select("p").head().getInt(0) == 3)
       assert(exc.getMessage.contains("Exec update failed: "))
     }
   }

   test("Update: fail if has subquery") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+      sql(s"CREATE TABLE $t (id bigint, name string, age int, p int)" +
+        " USING foo" +
+        " PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (1L, 'Herry', 26, 1)," +
+        s" (2L, 'Jack', 31, 2), (3L, 'Lisa', 28, 3), (4L, 'Frank', 33, 3)")
       val exc = intercept[AnalysisException] {
-        sql(s"UPDATE $t SET data='d' WHERE id IN (SELECT id FROM $t)")
+        sql(s"UPDATE $t SET name='Robert' WHERE id IN (SELECT id FROM $t)")
       }

-      assert(spark.table(t).filter("id=3").select("data").head().getString(0) == "c")
+      assert(spark.table(t).filter("id=4").select("name").head().getString(0) == "Frank")
       assert(exc.getMessage.contains("Update by condition with subquery is not supported"))
     }
   }
