Skip to content

Commit b88ec46

Browse files
committed
Add write tests.
1 parent b17ba30 commit b88ec46

File tree

5 files changed

+341
-3
lines changed

5 files changed

+341
-3
lines changed

spark/src/test/java/org/apache/iceberg/spark/SparkTestBase.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.iceberg.spark;
2121

22+
import com.google.common.collect.Iterables;
2223
import java.util.List;
2324
import java.util.stream.Collectors;
2425
import java.util.stream.IntStream;
@@ -28,6 +29,7 @@
2829
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
2930
import org.apache.spark.sql.Row;
3031
import org.apache.spark.sql.SparkSession;
32+
import org.apache.spark.sql.internal.SQLConf;
3133
import org.junit.AfterClass;
3234
import org.junit.Assert;
3335
import org.junit.BeforeClass;
@@ -49,6 +51,7 @@ public static void startMetastoreAndSpark() {
4951

5052
SparkTestBase.spark = SparkSession.builder()
5153
.master("local[2]")
54+
.config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic")
5255
.config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname))
5356
.enableHiveSupport()
5457
.getOrCreate();
@@ -79,6 +82,14 @@ protected List<Object[]> sql(String query, Object... args) {
7982
).collect(Collectors.toList());
8083
}
8184

85+
protected Object scalarSql(String query, Object... args) {
86+
List<Object[]> rows = sql(query, args);
87+
Assert.assertEquals("Scalar SQL should return one row", 1, rows.size());
88+
Object[] row = Iterables.getOnlyElement(rows);
89+
Assert.assertEquals("Scalar SQL should return one value", 1, row.length);
90+
return row[0];
91+
}
92+
8293
protected void assertEquals(String context, List<Object[]> expectedRows, List<Object[]> actualRows) {
8394
Assert.assertEquals(context + ": number of results should match", expectedRows.size(), actualRows.size());
8495
for (int i = 0; i < expectedRows.size(); i += 1) {

spark/src/test/java/org/apache/iceberg/spark/source/SimpleRecord.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public class SimpleRecord {
2828
public SimpleRecord() {
2929
}
3030

31-
SimpleRecord(Integer id, String data) {
31+
public SimpleRecord(Integer id, String data) {
3232
this.id = id;
3333
this.data = data;
3434
}

spark3/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ class SparkWriteBuilder implements WriteBuilder, SupportsDynamicOverwrite, Suppo
6666
this.dsSchema = info.schema();
6767
this.options = info.options();
6868
this.overwriteMode = options.containsKey("overwrite-mode") ?
69-
options.get("overwrite-mode").toLowerCase(Locale.ROOT) : null;
69+
options.get("overwrite-mode").toLowerCase(Locale.ROOT) :
70+
spark.sqlContext().conf().partitionOverwriteMode().toString().toLowerCase(Locale.ROOT);
7071
}
7172

7273
private JavaSparkContext lazySparkContext() {
@@ -87,7 +88,17 @@ public WriteBuilder overwriteDynamicPartitions() {
8788
public WriteBuilder overwrite(Filter[] filters) {
8889
this.overwriteExpr = SparkFilters.convert(filters);
8990
if (overwriteExpr == Expressions.alwaysTrue() && "dynamic".equals(overwriteMode)) {
90-
// use the write option to override truncating the table. use dynamic overwrite instead.
91+
// this is a work-around for a Spark bug, where Spark will use a static overwrite expression, alwaysTrue. this
92+
// happens when Spark checks whether an INSERT plan should use dynamic overwrite or a static overwrite. Spark uses the
93+
// number of identity partitions instead of the total number of partitions and defaults to static when the number
94+
// of static values provided is equal. if the table has hidden partitions, then it looks like the overwrite
95+
// should be static when there are no static values provided. instead, Spark should rely on the overwrite mode
96+
// when the number of identity partitions and the number of static values is equal.
97+
//
98+
// here, we detect the bug by catching alwaysTrue, which indicates that there were no static partition values.
99+
// there is a slight chance that overwriting the entire table was intended. there are two paths that will result
100+
// in a truncate or overwrite(true): DataFrameWriter with mode overwrite, which was a dynamic overwrite in 2.4,
101+
// and the new DataFrameWriterV2 using overwrite(lit(true)).
91102
this.overwriteDynamic = true;
92103
} else {
93104
Preconditions.checkState(!overwriteDynamic, "Cannot overwrite dynamically and by filter: %s", overwriteExpr);
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iceberg.spark.sql;
21+
22+
import java.util.List;
23+
import java.util.Map;
24+
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
25+
import org.apache.iceberg.spark.SparkCatalogTestBase;
26+
import org.apache.iceberg.spark.source.SimpleRecord;
27+
import org.apache.spark.sql.Dataset;
28+
import org.apache.spark.sql.Row;
29+
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
30+
import org.apache.spark.sql.functions;
31+
import org.junit.After;
32+
import org.junit.Assert;
33+
import org.junit.Before;
34+
import org.junit.Test;
35+
36+
public class TestPartitionedWrites extends SparkCatalogTestBase {
37+
public TestPartitionedWrites(String catalogName, String implementation, Map<String, String> config) {
38+
super(catalogName, implementation, config);
39+
}
40+
41+
@Before
42+
public void createTables() {
43+
sql("CREATE TABLE %s (id bigint, data string) USING iceberg PARTITIONED BY (truncate(id, 3))", tableName);
44+
sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName);
45+
}
46+
47+
@After
48+
public void removeTables() {
49+
sql("DROP TABLE IF EXISTS %s", tableName);
50+
}
51+
52+
@Test
53+
public void testInsertAppend() {
54+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
55+
56+
sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", tableName);
57+
58+
Assert.assertEquals("Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName));
59+
60+
List<Object[]> expected = ImmutableList.of(
61+
new Object[] { 1L, "a" },
62+
new Object[] { 2L, "b" },
63+
new Object[] { 3L, "c" },
64+
new Object[] { 4L, "d" },
65+
new Object[] { 5L, "e" }
66+
);
67+
68+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
69+
}
70+
71+
@Test
72+
public void testInsertOverwrite() {
73+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
74+
75+
// 4 and 5 replace 3 in the partition (id - (id % 3)) = 3
76+
sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", tableName);
77+
78+
Assert.assertEquals("Should have 4 rows after overwrite", 4L, scalarSql("SELECT count(*) FROM %s", tableName));
79+
80+
List<Object[]> expected = ImmutableList.of(
81+
new Object[] { 1L, "a" },
82+
new Object[] { 2L, "b" },
83+
new Object[] { 4L, "d" },
84+
new Object[] { 5L, "e" }
85+
);
86+
87+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
88+
}
89+
90+
@Test
91+
public void testDataFrameV2Append() throws NoSuchTableException {
92+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
93+
94+
List<SimpleRecord> data = ImmutableList.of(
95+
new SimpleRecord(4, "d"),
96+
new SimpleRecord(5, "e")
97+
);
98+
Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
99+
100+
ds.writeTo(tableName).append();
101+
102+
Assert.assertEquals("Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName));
103+
104+
List<Object[]> expected = ImmutableList.of(
105+
new Object[] { 1L, "a" },
106+
new Object[] { 2L, "b" },
107+
new Object[] { 3L, "c" },
108+
new Object[] { 4L, "d" },
109+
new Object[] { 5L, "e" }
110+
);
111+
112+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
113+
}
114+
115+
@Test
116+
public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException {
117+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
118+
119+
List<SimpleRecord> data = ImmutableList.of(
120+
new SimpleRecord(4, "d"),
121+
new SimpleRecord(5, "e")
122+
);
123+
Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
124+
125+
ds.writeTo(tableName).overwritePartitions();
126+
127+
Assert.assertEquals("Should have 4 rows after overwrite", 4L, scalarSql("SELECT count(*) FROM %s", tableName));
128+
129+
List<Object[]> expected = ImmutableList.of(
130+
new Object[] { 1L, "a" },
131+
new Object[] { 2L, "b" },
132+
new Object[] { 4L, "d" },
133+
new Object[] { 5L, "e" }
134+
);
135+
136+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
137+
}
138+
139+
@Test
140+
public void testDataFrameV2Overwrite() throws NoSuchTableException {
141+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
142+
143+
List<SimpleRecord> data = ImmutableList.of(
144+
new SimpleRecord(4, "d"),
145+
new SimpleRecord(5, "e")
146+
);
147+
Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
148+
149+
ds.writeTo(tableName).overwrite(functions.col("id").$less(3));
150+
151+
Assert.assertEquals("Should have 3 rows after overwrite", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
152+
153+
List<Object[]> expected = ImmutableList.of(
154+
new Object[] { 3L, "c" },
155+
new Object[] { 4L, "d" },
156+
new Object[] { 5L, "e" }
157+
);
158+
159+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
160+
}
161+
}
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iceberg.spark.sql;
21+
22+
import java.util.List;
23+
import java.util.Map;
24+
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
25+
import org.apache.iceberg.spark.SparkCatalogTestBase;
26+
import org.apache.iceberg.spark.source.SimpleRecord;
27+
import org.apache.spark.sql.Dataset;
28+
import org.apache.spark.sql.Row;
29+
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
30+
import org.apache.spark.sql.functions;
31+
import org.junit.After;
32+
import org.junit.Assert;
33+
import org.junit.Before;
34+
import org.junit.Test;
35+
36+
public class TestUnpartitionedWrites extends SparkCatalogTestBase {
37+
public TestUnpartitionedWrites(String catalogName, String implementation, Map<String, String> config) {
38+
super(catalogName, implementation, config);
39+
}
40+
41+
@Before
42+
public void createTables() {
43+
sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName);
44+
sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName);
45+
}
46+
47+
@After
48+
public void removeTables() {
49+
sql("DROP TABLE IF EXISTS %s", tableName);
50+
}
51+
52+
@Test
53+
public void testInsertAppend() {
54+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
55+
56+
sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", tableName);
57+
58+
Assert.assertEquals("Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName));
59+
60+
List<Object[]> expected = ImmutableList.of(
61+
new Object[] { 1L, "a" },
62+
new Object[] { 2L, "b" },
63+
new Object[] { 3L, "c" },
64+
new Object[] { 4L, "d" },
65+
new Object[] { 5L, "e" }
66+
);
67+
68+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
69+
}
70+
71+
@Test
72+
public void testInsertOverwrite() {
73+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
74+
75+
sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", tableName);
76+
77+
Assert.assertEquals("Should have 2 rows after overwrite", 2L, scalarSql("SELECT count(*) FROM %s", tableName));
78+
79+
List<Object[]> expected = ImmutableList.of(
80+
new Object[] { 4L, "d" },
81+
new Object[] { 5L, "e" }
82+
);
83+
84+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
85+
}
86+
87+
@Test
88+
public void testDataFrameV2Append() throws NoSuchTableException {
89+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
90+
91+
List<SimpleRecord> data = ImmutableList.of(
92+
new SimpleRecord(4, "d"),
93+
new SimpleRecord(5, "e")
94+
);
95+
Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
96+
97+
ds.writeTo(tableName).append();
98+
99+
Assert.assertEquals("Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName));
100+
101+
List<Object[]> expected = ImmutableList.of(
102+
new Object[] { 1L, "a" },
103+
new Object[] { 2L, "b" },
104+
new Object[] { 3L, "c" },
105+
new Object[] { 4L, "d" },
106+
new Object[] { 5L, "e" }
107+
);
108+
109+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
110+
}
111+
112+
@Test
113+
public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException {
114+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
115+
116+
List<SimpleRecord> data = ImmutableList.of(
117+
new SimpleRecord(4, "d"),
118+
new SimpleRecord(5, "e")
119+
);
120+
Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
121+
122+
ds.writeTo(tableName).overwritePartitions();
123+
124+
Assert.assertEquals("Should have 2 rows after overwrite", 2L, scalarSql("SELECT count(*) FROM %s", tableName));
125+
126+
List<Object[]> expected = ImmutableList.of(
127+
new Object[] { 4L, "d" },
128+
new Object[] { 5L, "e" }
129+
);
130+
131+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
132+
}
133+
134+
@Test
135+
public void testDataFrameV2Overwrite() throws NoSuchTableException {
136+
Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName));
137+
138+
List<SimpleRecord> data = ImmutableList.of(
139+
new SimpleRecord(4, "d"),
140+
new SimpleRecord(5, "e")
141+
);
142+
Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
143+
144+
ds.writeTo(tableName).overwrite(functions.col("id").$less$eq(3));
145+
146+
Assert.assertEquals("Should have 2 rows after overwrite", 2L, scalarSql("SELECT count(*) FROM %s", tableName));
147+
148+
List<Object[]> expected = ImmutableList.of(
149+
new Object[] { 4L, "d" },
150+
new Object[] { 5L, "e" }
151+
);
152+
153+
assertEquals("Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName));
154+
}
155+
}

0 commit comments

Comments
 (0)