Self serve replication SQL API #226

Open · wants to merge 3 commits into base: main
@@ -0,0 +1,231 @@
package com.linkedin.openhouse.spark.statementtest;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.linkedin.openhouse.spark.sql.catalyst.parser.extensions.OpenhouseParseException;
import java.nio.file.Files;
import lombok.SneakyThrows;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.execution.ExplainMode;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class SetTableReplicationPolicyStatementTest {
private static SparkSession spark = null;

@SneakyThrows
@BeforeAll
public void setupSpark() {
Path unittest = new Path(Files.createTempDirectory("unittest_settablepolicy").toString());
spark =
SparkSession.builder()
.master("local[2]")
.config(
"spark.sql.extensions",
("org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions,"
+ "com.linkedin.openhouse.spark.extensions.OpenhouseSparkSessionExtensions"))
.config("spark.sql.catalog.openhouse", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.openhouse.type", "hadoop")
.config("spark.sql.catalog.openhouse.warehouse", unittest.toString())
.getOrCreate();
}

@Test
public void testSimpleSetReplicationPolicy() {
String replicationConfigJson = "{\"cluster\":\"a\", \"interval\":\"b\"}";
Dataset<Row> ds =
spark.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
+ "({cluster:'a', interval:'b'}))");
assert isPlanValid(ds, replicationConfigJson);

// Test support with multiple clusters
replicationConfigJson =
"{\"cluster\":\"a\", \"interval\":\"b\"}, {\"cluster\":\"aa\", \"interval\":\"bb\"}";
ds =
spark.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
+ "({cluster:'a', interval:'b'}, {cluster:'aa', interval:'bb'}))");
Review comment (Collaborator): for the server side, can we add a regex validator for the user-provided input? cc: @rohitkum2506

Reply (Author): yep, the validations will be done on the server side. (A sketch of such a validator follows this test file.)

assert isPlanValid(ds, replicationConfigJson);

// Test with the optional interval omitted
replicationConfigJson = "{\"cluster\":\"a\"}";
ds =
spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = " + "({cluster:'a'}))");
assert isPlanValid(ds, replicationConfigJson);

// Test with the optional interval omitted for multiple clusters
replicationConfigJson = "{\"cluster\":\"a\"}, {\"cluster\":\"b\"}";
ds =
spark.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
+ "({cluster:'a'}, {cluster:'b'}))");
assert isPlanValid(ds, replicationConfigJson);
}

@Test
public void testReplicationPolicyWithoutProperSyntax() {
// Empty cluster value
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:}))")
.show());

// Empty interval value
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa', interval:}))")
.show());

// Missing cluster value but interval present
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:, interval: 'bb'}))")
.show());

// Missing interval value but keyword present
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'a', interval:}))")
.show());

// Missing cluster value for multiple clusters
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:, interval:'a'}, {cluster:, interval: 'b'}))")
.show());

// Missing cluster keyword for multiple clusters
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({interval:'a'}, {interval: 'b'}))")
.show());

// Missing cluster keyword
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({interval: 'ss'}))")
.show());

// Typo in keyword interval
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: 'aa', interv: 'ss'}))")
.show());

// Typo in keyword cluster
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({clustr: 'aa', interval: 'ss'}))")
.show());

// Missing quote in cluster value
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster: aa', interval: 'ss}))")
.show());

// Typo in REPLICATION keyword
Assertions.assertThrows(
OpenhouseParseException.class,
() ->
spark
.sql(
"ALTER TABLE openhouse.db.table SET POLICY (REPLICAT = ({cluster: 'aa', interval: 'ss'}))")
.show());

// Missing cluster and interval values
Assertions.assertThrows(
OpenhouseParseException.class,
() -> spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({}))").show());
}

@BeforeEach
public void setup() {
spark.sql("CREATE TABLE openhouse.db.table (id bigint, data string) USING iceberg").show();
spark.sql("CREATE TABLE openhouse.0_.0_ (id bigint, data string) USING iceberg").show();
spark
.sql("ALTER TABLE openhouse.db.table SET TBLPROPERTIES ('openhouse.tableId' = 'tableid')")
.show();
spark
.sql("ALTER TABLE openhouse.0_.0_ SET TBLPROPERTIES ('openhouse.tableId' = 'tableid')")
.show();
}

@AfterEach
public void tearDown() {
spark.sql("DROP TABLE openhouse.db.table").show();
spark.sql("DROP TABLE openhouse.0_.0_").show();
}

@AfterAll
public void tearDownSpark() {
spark.close();
}

@SneakyThrows
private boolean isPlanValid(Dataset<Row> dataframe, String replicationConfigJson) {
replicationConfigJson = "[" + replicationConfigJson + "]";
String queryStr = dataframe.queryExecution().explainString(ExplainMode.fromString("simple"));
JsonArray jsonArray = new Gson().fromJson(replicationConfigJson, JsonArray.class);
// Every cluster (and interval, when present) must appear in the plan; a single miss fails the check.
for (JsonElement element : jsonArray) {
JsonObject entry = element.getAsJsonObject();
String cluster = entry.get("cluster").getAsString();
if (!queryStr.contains(cluster)) {
return false;
}
if (entry.has("interval") && !queryStr.contains(entry.get("interval").getAsString())) {
return false;
}
}
return true;
}
}
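
On the review thread above: since validation is deferred to the server side, here is a minimal sketch of what such a regex validator for the user-provided cluster and interval strings could look like. The object name, the accepted cluster-name characters, and the <n>H/<n>D interval format are assumptions for illustration, not part of this PR.

// Hedged sketch only: the names and formats below are assumptions, not part of this PR.
object ReplicationInputValidator {
  // Assumed cluster-name shape: alphanumerics, underscores, and hyphens.
  private val ClusterPattern = "^[A-Za-z0-9_-]+$".r
  // Assumed interval shape: a positive integer followed by H (hours) or D (days).
  private val IntervalPattern = "^\\d+[HD]$".r

  def validate(cluster: String, interval: Option[String]): Unit = {
    require(ClusterPattern.findFirstIn(cluster).isDefined, s"Invalid cluster: $cluster")
    interval.foreach { i =>
      require(IntervalPattern.findFirstIn(i).isDefined, s"Invalid interval: $i")
    }
  }
}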
@@ -24,6 +24,7 @@ singleStatement

statement
: ALTER TABLE multipartIdentifier SET POLICY '(' retentionPolicy (columnRetentionPolicy)? ')' #setRetentionPolicy
| ALTER TABLE multipartIdentifier SET POLICY '(' replicationPolicy ')' #setReplicationPolicy
| ALTER TABLE multipartIdentifier SET POLICY '(' sharingPolicy ')' #setSharingPolicy
| ALTER TABLE multipartIdentifier MODIFY columnNameClause SET columnPolicy #setColumnPolicyTag
| GRANT privilege ON grantableResource TO principal #grantStatement
@@ -64,7 +65,7 @@ quotedIdentifier
;

nonReserved
: ALTER | TABLE | SET | POLICY | RETENTION | SHARING
: ALTER | TABLE | SET | POLICY | RETENTION | SHARING | REPLICATION
| GRANT | REVOKE | ON | TO | SHOW | GRANTS | PATTERN | WHERE | COLUMN
;

@@ -83,6 +84,26 @@ columnRetentionPolicy
: ON columnNameClause (columnRetentionPolicyPatternClause)?
;

replicationPolicy
: REPLICATION '=' tableReplicationPolicy
;

tableReplicationPolicy
: '(' replicationPolicyClause (',' replicationPolicyClause)* ')'
;

replicationPolicyClause
: '{' replicationPolicyClusterClause (',' replicationPolicyIntervalClause)? '}'
;

replicationPolicyClusterClause
: CLUSTER ':' STRING
;

replicationPolicyIntervalClause
: INTERVAL ':' STRING
;

columnRetentionPolicyPatternClause
: WHERE retentionColumnPatternClause
;
@@ -136,6 +157,7 @@ TABLE: 'TABLE';
SET: 'SET';
POLICY: 'POLICY';
RETENTION: 'RETENTION';
REPLICATION: 'REPLICATION';
SHARING: 'SHARING';
GRANT: 'GRANT';
REVOKE: 'REVOKE';
@@ -150,6 +172,8 @@ DATABASE: 'DATABASE';
SHOW: 'SHOW';
GRANTS: 'GRANTS';
PATTERN: 'PATTERN';
CLUSTER: 'CLUSTER';
INTERVAL: 'INTERVAL';
WHERE: 'WHERE';
COLUMN: 'COLUMN';
PII: 'PII';
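
Taken together, the new rules accept one or more {cluster: '<name>'} groups, each with an optional interval entry. For reference, a minimal sketch of the statement shapes this grammar admits, assuming a SparkSession configured with the Openhouse extensions as in the test above (table and policy values are illustrative):

// Single cluster; the interval clause is optional.
spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:'a'}))")

// Single cluster with an interval.
spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:'a', interval:'b'}))")

// Multiple clusters; each group may carry its own optional interval.
spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({cluster:'a', interval:'b'}, {cluster:'aa'}))")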
@@ -2,13 +2,14 @@ package com.linkedin.openhouse.spark.sql.catalyst.parser.extensions

import com.linkedin.openhouse.spark.sql.catalyst.enums.GrantableResourceTypes
import com.linkedin.openhouse.spark.sql.catalyst.parser.extensions.OpenhouseSqlExtensionsParser._
import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetRetentionPolicy, SetSharingPolicy, SetColumnPolicyTag, ShowGrantsStatement}
import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetColumnPolicyTag, SetReplicationPolicy, SetRetentionPolicy, SetSharingPolicy, ShowGrantsStatement}
import com.linkedin.openhouse.spark.sql.catalyst.enums.GrantableResourceTypes.GrantableResourceType
import com.linkedin.openhouse.gen.tables.client.model.TimePartitionSpec
import org.antlr.v4.runtime.tree.ParseTree
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import scala.collection.JavaConversions.iterableAsScalaIterable
import scala.collection.JavaConverters._

class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends OpenhouseSqlExtensionsBaseVisitor[AnyRef] {
@@ -26,6 +27,12 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends OpenhouseSqlExtensionsBaseVisitor[AnyRef] {
SetRetentionPolicy(tableName, granularity, count, Option(colName), Option(colPattern))
}

override def visitSetReplicationPolicy(ctx: SetReplicationPolicyContext): SetReplicationPolicy = {
val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
val replicationPolicies = typedVisit[Seq[(String, Option[String])]](ctx.replicationPolicy())
SetReplicationPolicy(tableName, replicationPolicies)
}

override def visitSetSharingPolicy(ctx: SetSharingPolicyContext): SetSharingPolicy = {
val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
val sharing = typedVisit[String](ctx.sharingPolicy())
@@ -86,6 +93,30 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends OpenhouseSqlExtensionsBaseVisitor[AnyRef] {
typedVisit[(String, Int)](ctx.duration())
}

override def visitReplicationPolicy(ctx: ReplicationPolicyContext): Seq[(String, Option[String])] = {
typedVisit[Seq[(String, Option[String])]](ctx.tableReplicationPolicy())
}

override def visitTableReplicationPolicy(ctx: TableReplicationPolicyContext): Seq[(String, Option[String])] = {
toSeq(ctx.replicationPolicyClause()).map(typedVisit[(String, Option[String])])
}

override def visitReplicationPolicyClause(ctx: ReplicationPolicyClauseContext): (String, Option[String]) = {
val cluster = typedVisit[String](ctx.replicationPolicyClusterClause())
// The interval clause is optional in the grammar, so it may be absent.
val interval = Option(ctx.replicationPolicyIntervalClause()).map(typedVisit[String])
(cluster, interval)
}

override def visitReplicationPolicyClusterClause(ctx: ReplicationPolicyClusterClauseContext): String = {
ctx.STRING().getText
}

override def visitReplicationPolicyIntervalClause(ctx: ReplicationPolicyIntervalClauseContext): String = {
ctx.STRING().getText
}

override def visitColumnRetentionPolicy(ctx: ColumnRetentionPolicyContext): (String, String) = {
if (ctx.columnRetentionPolicyPatternClause() != null) {
(ctx.columnNameClause().identifier().getText(), ctx.columnRetentionPolicyPatternClause().retentionColumnPatternClause().STRING().getText)
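
To make the parsed shape concrete: the visitors above reduce each {cluster: ..., interval: ...} group to a (String, Option[String]) pair, and tableReplicationPolicy collects the pairs into the Seq carried by SetReplicationPolicy. One caveat: getText on an ANTLR STRING terminal returns the matched token text including its surrounding quotes, so at this layer the values are the quoted strings; any quote stripping is assumed to happen downstream. A sketch of the result for a two-cluster statement:

// Parsed form of: REPLICATION = ({cluster:'a', interval:'b'}, {cluster:'aa'})
// Values keep their quotes here because ctx.STRING().getText returns the raw token text.
val replicationPolicies: Seq[(String, Option[String])] =
  Seq(("'a'", Some("'b'")), ("'aa'", None))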
@@ -0,0 +1,9 @@
package com.linkedin.openhouse.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.plans.logical.Command

case class SetReplicationPolicy(tableName: Seq[String], replicationPolicies: Seq[(String, Option[String])]) extends Command {
override def simpleString(maxFields: Int): String = {
s"SetReplicationPolicy: ${tableName} ${replicationPolicies}"
}
}
@@ -1,6 +1,6 @@
package com.linkedin.openhouse.spark.sql.execution.datasources.v2

import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetRetentionPolicy, SetSharingPolicy, SetColumnPolicyTag, ShowGrantsStatement}
import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetColumnPolicyTag, SetReplicationPolicy, SetRetentionPolicy, SetSharingPolicy, ShowGrantsStatement}
import org.apache.iceberg.spark.{Spark3Util, SparkCatalog, SparkSessionCatalog}
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
@@ -15,6 +15,8 @@ case class OpenhouseDataSourceV2Strategy(spark: SparkSession) extends Strategy with PredicateHelper {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case SetRetentionPolicy(CatalogAndIdentifierExtractor(catalog, ident), granularity, count, colName, colPattern) =>
SetRetentionPolicyExec(catalog, ident, granularity, count, colName, colPattern) :: Nil
case SetReplicationPolicy(CatalogAndIdentifierExtractor(catalog, ident), replicationPolicies) =>
SetReplicationPolicyExec(catalog, ident, replicationPolicies) :: Nil
case SetSharingPolicy(CatalogAndIdentifierExtractor(catalog, ident), sharing) =>
SetSharingPolicyExec(catalog, ident, sharing) :: Nil
case SetColumnPolicyTag(CatalogAndIdentifierExtractor(catalog, ident), policyTag, cols) =>