
Commit bcafbdb

Author: xy_xin (committed)

Add UPDATE support for DataSource V2
1 parent 76ebf22 commit bcafbdb

File tree

18 files changed: +561 -23 lines


core/src/main/scala/org/apache/spark/internal/Logging.scala

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ trait Logging {

   // Make the log field transient so that objects with Logging can
   // be serialized and used on another machine
-  @transient private var log_ : Logger = null
+  @transient private var log_ : Logger = null

   // Method to get the logger name for this object
   protected def logName = {

sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala

docs/sql-keywords.md

Lines changed: 1 addition & 0 deletions
@@ -280,6 +280,7 @@ Below is a list of all the keywords in Spark SQL.
 <tr><td>UNKNOWN</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>
 <tr><td>UNLOCK</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>UNSET</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
+<tr><td>UPDATE</td><td>non-reserved</td><td>non-reserved</td><td>reserved</td></tr>
 <tr><td>USE</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
 <tr><td>USER</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>
 <tr><td>USING</td><td>reserved</td><td>strict-non-reserved</td><td>reserved</td></tr>

sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4

Lines changed: 12 additions & 0 deletions
@@ -217,6 +217,7 @@ statement
     | SET .*?                                                          #setConfiguration
     | RESET                                                            #resetConfiguration
     | DELETE FROM multipartIdentifier tableAlias whereClause?          #deleteFromTable
+    | UPDATE multipartIdentifier tableAlias setClause whereClause?     #updateTable
     | unsupportedHiveNativeCommands .*?                                #failNativeCommand
     ;

@@ -476,6 +477,14 @@ selectClause
     : SELECT (hints+=hint)* setQuantifier? namedExpressionSeq
     ;

+setClause
+    : SET assign (',' assign)*
+    ;
+
+assign
+    : key=multipartIdentifier EQ value=expression
+    ;
+
 whereClause
     : WHERE booleanExpression
     ;

@@ -1085,6 +1094,7 @@ ansiNonReserved
     | UNCACHE
     | UNLOCK
     | UNSET
+    | UPDATE
     | USE
     | VALUES
     | VIEW

@@ -1355,6 +1365,7 @@ nonReserved
     | UNKNOWN
     | UNLOCK
     | UNSET
+    | UPDATE
     | USE
     | USER
     | VALUES

@@ -1622,6 +1633,7 @@ UNIQUE: 'UNIQUE';
 UNKNOWN: 'UNKNOWN';
 UNLOCK: 'UNLOCK';
 UNSET: 'UNSET';
+UPDATE: 'UPDATE';
 USE: 'USE';
 USER: 'USER';
 USING: 'USING';
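Taken together, the new setClause and assign rules let the parser accept UPDATE statements of the following shape (an illustrative sketch: the catalog, table, and column names are made up, and end-to-end execution still depends on the target source supporting updates):

// Illustrative only: assumes a SparkSession `spark` and a v2 catalog "testcat"
// containing a table ns1.tbl with columns id, name, and age.
spark.sql(
  """UPDATE testcat.ns1.tbl AS t
    |SET t.name = 'Robert', t.age = t.age + 1
    |WHERE t.id = 1""".stripMargin)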
sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsUpdate.java

Lines changed: 56 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connector.catalog;

import java.util.Map;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.sources.Filter;

/**
 * A mix-in interface for {@link Table} update support. Data sources can implement this
 * interface to provide the ability to update data that matches filter expressions
 * with the given set expressions.
 */
@Experimental
public interface SupportsUpdate {

  /**
   * Update data that matches filter expressions with the given set expressions for a data
   * source table.
   * <p>
   * Rows will be updated with the given values iff all of the filter expressions match.
   * That is, the expressions must be interpreted as a set of filters that are ANDed together.
   * <p>
   * Implementations may reject an update operation if the update isn't possible without
   * significant effort or if they cannot handle the set expressions. For example, partitioned
   * data sources may reject updates that do not filter by partition columns, because such a
   * filter may require rewriting files. An update may also be rejected if it requires a
   * complex computation that the data source does not support.
   * To reject an update, implementations should throw {@link IllegalArgumentException} with a
   * clear error message that identifies which expression was rejected.
   *
   * @param sets the fields to be updated and the corresponding updated values, in the form of
   *             key-value pairs in a map. A value can be a literal or a simple expression
   *             such as {{{ originalValue + 1 }}}.
   * @param filters filter expressions, used to select rows to update when all expressions match
   * @throws IllegalArgumentException if the update is rejected due to required effort
   *                                  or an unsupported update expression
   */
  void updateWhere(Map<String, Expression> sets, Filter[] filters);
}
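As an illustration, a simple in-memory connector might implement this interface roughly as follows. This is a minimal sketch: InMemoryUpdatableTable and its row representation are hypothetical, and only literal set values and EqualTo filters are handled, with everything else rejected per the interface contract.

import java.util.{Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.spark.sql.connector.catalog.SupportsUpdate
import org.apache.spark.sql.connector.expressions.{Expression, LiteralValue}
import org.apache.spark.sql.sources.{EqualTo, Filter}

// Hypothetical sketch: rows are kept as mutable maps from column name to value.
class InMemoryUpdatableTable(rows: Seq[mutable.Map[String, Any]]) extends SupportsUpdate {

  override def updateWhere(sets: JMap[String, Expression], filters: Array[Filter]): Unit = {
    // Accept only literal set values; reject anything else with a clear message.
    val literalSets = sets.asScala.map {
      case (col, LiteralValue(value, _)) => col -> value
      case (col, other) =>
        throw new IllegalArgumentException(s"Unsupported set expression for $col: $other")
    }
    // All filters are ANDed together; only EqualTo is handled in this sketch.
    def matches(row: mutable.Map[String, Any]): Boolean = filters.forall {
      case EqualTo(attr, value) => row.get(attr).contains(value)
      case f => throw new IllegalArgumentException(s"Unsupported filter: $f")
    }
    rows.filter(matches).foreach(row => literalSets.foreach { case (c, v) => row(c) = v })
  }
}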

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 0 deletions
@@ -1761,6 +1761,8 @@ class Analyzer(
         resolveSubQueries(q, q.children)
       case d: DeleteFromTable if d.childrenResolved =>
         resolveSubQueries(d, d.children)
+      case u: UpdateTable if u.childrenResolved =>
+        resolveSubQueries(u, u.children)
     }
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 5 additions & 4 deletions
@@ -587,19 +587,20 @@ trait CheckAnalysis extends PredicateHelper {
         // Only certain operators are allowed to host subquery expression containing
         // outer references.
         plan match {
-          case _: Filter | _: Aggregate | _: Project | _: DeleteFromTable => // Ok
+          case _: Filter | _: Aggregate | _: Project | _: DeleteFromTable | _: UpdateTable =>
+            // Ok
           case other => failAnalysis(
             "Correlated scalar sub-queries can only be used in a " +
-              s"Filter/Aggregate/Project: $plan")
+              s"Filter/Aggregate/Project/DeleteFromTable/UpdateTable: $plan")
         }

       case inSubqueryOrExistsSubquery =>
         plan match {
-          case _: Filter | _: DeleteFromTable => // Ok
+          case _: Filter | _: DeleteFromTable | _: UpdateTable => // Ok
           case _ =>
             failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" +
-              s" Filter/DeleteFromTable: $plan")
+              s" Filter/DeleteFromTable/UpdateTable: $plan")
         }
     }
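With UpdateTable added to the allowed hosts, IN/EXISTS and correlated scalar subqueries may now appear in an UPDATE's WHERE clause, for example (a sketch; the table names are hypothetical):

// Hypothetical tables; the analyzer now accepts the IN subquery under UpdateTable.
spark.sql(
  """UPDATE testcat.ns1.tbl
    |SET status = 'archived'
    |WHERE id IN (SELECT id FROM testcat.ns1.stale_ids)""".stripMargin)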

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 31 additions & 1 deletion
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
 import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DeleteFromStatement, DescribeColumnStatement, DescribeTableStatement, DropTableStatement, DropViewStatement, InsertIntoStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement, ShowNamespacesStatement, ShowTablesStatement}
+import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DeleteFromStatement, DescribeColumnStatement, DescribeTableStatement, DropTableStatement, DropViewStatement, InsertIntoStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement, ShowNamespacesStatement, ShowTablesStatement, UpdateTableStatement}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
 import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
 import org.apache.spark.sql.internal.SQLConf

@@ -361,6 +361,36 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     DeleteFromStatement(tableId, tableAlias, predicate)
   }

+  override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
+    val tableId = visitMultipartIdentifier(ctx.multipartIdentifier)
+    val tableAlias = if (ctx.tableAlias() != null) {
+      val ident = ctx.tableAlias().strictIdentifier()
+      // We do not allow column aliases after the table alias.
+      if (ctx.tableAlias().identifierList() != null) {
+        throw new ParseException("Column aliases are not allowed in UPDATE.",
+          ctx.tableAlias().identifierList())
+      }
+      if (ident != null) Some(ident.getText) else None
+    } else {
+      None
+    }
+    val (attrs, values) = ctx.setClause().assign().asScala.map {
+      kv => visitMultipartIdentifier(kv.key) -> expression(kv.value)
+    }.unzip
+    val predicate = if (ctx.whereClause() != null) {
+      Some(expression(ctx.whereClause().booleanExpression()))
+    } else {
+      None
+    }
+
+    UpdateTableStatement(
+      tableId,
+      tableAlias,
+      attrs,
+      values,
+      predicate)
+  }
+
   /**
    * Create a partition specification map.
    */
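For example, the visitor above turns an UPDATE statement into an UpdateTableStatement whose attrs and values sequences line up pairwise (a sketch; the commented-out result is approximate):

import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

val plan = CatalystSqlParser.parsePlan(
  "UPDATE t SET name = 'abc', age = age + 1 WHERE id = 1")
// Roughly: UpdateTableStatement(
//   tableName = Seq("t"), tableAlias = None,
//   attrs = Seq(Seq("name"), Seq("age")),
//   values = Seq(Literal("abc"), UnresolvedAttribute("age") + Literal(1)),
//   condition = Some(UnresolvedAttribute("id") === Literal(1)))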

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 15 additions & 0 deletions
@@ -605,6 +605,21 @@ case class DeleteFromTable(
   override def children: Seq[LogicalPlan] = child :: Nil
 }

+/**
+ * Update the rows of a table that match the condition with the given updated values.
+ * NOTE: To cover nested fields, we let the type of `attrs` be `Seq[Expression]`, because
+ * nested fields may be resolved to objects of a type like `GetStructField`. However,
+ * currently only top-level fields are supported, in which case `attrs` is effectively
+ * of type `Seq[Attribute]`.
+ */
+case class UpdateTable(
+    child: LogicalPlan,
+    attrs: Seq[Expression],
+    values: Seq[Expression],
+    condition: Option[Expression]) extends Command {
+
+  override def children: Seq[LogicalPlan] = child :: Nil
+}
+
 /**
  * Drop a table.
  */
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/UpdateTableStatement.scala

Lines changed: 27 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.plans.logical.sql

import org.apache.spark.sql.catalyst.expressions.Expression

case class UpdateTableStatement(
    tableName: Seq[String],
    tableAlias: Option[String],
    attrs: Seq[Seq[String]],
    values: Seq[Expression],
    condition: Option[Expression]) extends ParsedStatement
sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala

Lines changed: 17 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.expressions

 import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.expressions.{Expression => CatalystExpression, Literal => CatalystLiteral}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType, StringType}

@@ -57,6 +58,22 @@ private[sql] object LogicalExpressions {
   def days(column: String): DaysTransform = DaysTransform(reference(column))

   def hours(column: String): HoursTransform = HoursTransform(reference(column))
+
+  /**
+   * Tries to translate a Catalyst [[CatalystExpression]] into a data source [[Expression]].
+   *
+   * @return `Some(Expression)` if the input [[CatalystExpression]] is convertible,
+   *         otherwise `None`.
+   */
+  private[sql] def translateExpression(value: CatalystExpression): Option[Expression] = {
+    // TODO: currently only literals are supported. More public expressions will be added later.
+    value match {
+      case l: CatalystLiteral =>
+        Some(LiteralValue(l.value, l.dataType))
+      case _ =>
+        None
+    }
+  }
 }

 /**
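A quick illustration of the helper's contract (a sketch; note that translateExpression is private[sql], so this only compiles from within Spark's own sql packages):

import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal}
import org.apache.spark.sql.connector.expressions.LogicalExpressions
import org.apache.spark.sql.types.IntegerType

// A literal is convertible to a connector LiteralValue...
LogicalExpressions.translateExpression(Literal(1))
// => Some(LiteralValue(1, IntegerType))

// ...while anything else, e.g. `age + 1`, is not yet supported and yields None.
val age = AttributeReference("age", IntegerType)()
LogicalExpressions.translateExpression(Add(age, Literal(1)))
// => None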
