From 69bcf05eff392f29438146db73b3b084c3fc3e29 Mon Sep 17 00:00:00 2001
From: Kyle Bendickson
Date: Thu, 18 Aug 2022 09:49:25 -0700
Subject: [PATCH] Spark 3.3: Support bucket in FunctionCatalog (#5513)

---
 .../org/apache/iceberg/transforms/Bucket.java |  40 +-
 .../org/apache/iceberg/util/BucketUtil.java   |  88 +++++
 .../iceberg/transforms/TestBucketing.java     |  19 +
 .../spark/functions/BucketFunction.java       | 310 +++++++++++++++
 .../spark/functions/SparkFunctions.java       |   1 +
 .../spark/sql/TestSparkBucketFunction.java    | 361 ++++++++++++++++++
 6 files changed, 788 insertions(+), 31 deletions(-)
 create mode 100644 api/src/main/java/org/apache/iceberg/util/BucketUtil.java
 create mode 100644 spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
 create mode 100644 spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java

diff --git a/api/src/main/java/org/apache/iceberg/transforms/Bucket.java b/api/src/main/java/org/apache/iceberg/transforms/Bucket.java
index ecbefa5cf015..32540bb923d8 100644
--- a/api/src/main/java/org/apache/iceberg/transforms/Bucket.java
+++ b/api/src/main/java/org/apache/iceberg/transforms/Bucket.java
@@ -22,7 +22,6 @@
 
 import java.math.BigDecimal;
 import java.nio.ByteBuffer;
-import java.nio.charset.StandardCharsets;
 import java.util.Set;
 import java.util.UUID;
 import org.apache.iceberg.expressions.BoundPredicate;
@@ -38,6 +37,7 @@ import org.apache.iceberg.relocated.com.google.common.hash.Hashing;
 import org.apache.iceberg.types.Type;
 import org.apache.iceberg.types.Types;
+import org.apache.iceberg.util.BucketUtil;
 
 abstract class Bucket<T> implements Transform<T, Integer> {
   private static final HashFunction MURMUR3 = Hashing.murmur3_32_fixed();
 
@@ -166,7 +166,7 @@ private BucketInteger(int numBuckets) {
 
     @Override
     public int hash(Integer value) {
-      return MURMUR3.hashLong(value.longValue()).asInt();
+      return BucketUtil.hash(value);
     }
 
     @Override
@@ -182,7 +182,7 @@ private BucketLong(int numBuckets) {
 
     @Override
     public int hash(Long value) {
-      return MURMUR3.hashLong(value).asInt();
+      return BucketUtil.hash(value);
    }
 
     @Override
@@ -202,7 +202,7 @@ static class BucketFloat extends Bucket<Float> {
 
     @Override
     public int hash(Float value) {
-      return MURMUR3.hashLong(Double.doubleToLongBits((double) value)).asInt();
+      return BucketUtil.hash(value);
     }
 
     @Override
@@ -220,7 +220,7 @@ static class BucketDouble extends Bucket<Double> {
 
     @Override
     public int hash(Double value) {
-      return MURMUR3.hashLong(Double.doubleToLongBits(value)).asInt();
+      return BucketUtil.hash(value);
     }
 
     @Override
@@ -236,7 +236,7 @@ private BucketString(int numBuckets) {
 
     @Override
     public int hash(CharSequence value) {
-      return MURMUR3.hashString(value, StandardCharsets.UTF_8).asInt();
+      return BucketUtil.hash(value);
     }
 
     @Override
@@ -254,24 +254,7 @@ private BucketByteBuffer(int numBuckets) {
 
     @Override
     public int hash(ByteBuffer value) {
-      if (value.hasArray()) {
-        return MURMUR3
-            .hashBytes(
-                value.array(),
-                value.arrayOffset() + value.position(),
-                value.arrayOffset() + value.remaining())
-            .asInt();
-      } else {
-        int position = value.position();
-        byte[] copy = new byte[value.remaining()];
-        try {
-          value.get(copy);
-        } finally {
-          // make sure the buffer position is unchanged
-          value.position(position);
-        }
-        return MURMUR3.hashBytes(copy).asInt();
-      }
+      return BucketUtil.hash(value);
     }
 
     @Override
@@ -287,12 +270,7 @@ private BucketUUID(int numBuckets) {
 
     @Override
     public int hash(UUID value) {
-      return MURMUR3
-          .newHasher(16)
-          .putLong(Long.reverseBytes(value.getMostSignificantBits()))
-          .putLong(Long.reverseBytes(value.getLeastSignificantBits()))
-          .hash()
-          .asInt();
+      return BucketUtil.hash(value);
     }
 
     @Override
@@ -308,7 +286,7 @@ private BucketDecimal(int numBuckets) {
 
     @Override
     public int hash(BigDecimal value) {
-      return MURMUR3.hashBytes(value.unscaledValue().toByteArray()).asInt();
+      return BucketUtil.hash(value);
     }
 
     @Override
diff --git a/api/src/main/java/org/apache/iceberg/util/BucketUtil.java b/api/src/main/java/org/apache/iceberg/util/BucketUtil.java
new file mode 100644
index 000000000000..a9b1ccb24a16
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/util/BucketUtil.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.util;
+
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.UUID;
+import org.apache.iceberg.relocated.com.google.common.hash.HashFunction;
+import org.apache.iceberg.relocated.com.google.common.hash.Hashing;
+
+/**
+ * Contains the logic for hashing various types for use with the {@code bucket} partition
+ * transformations.
+ */
+public class BucketUtil {
+
+  private static final HashFunction MURMUR3 = Hashing.murmur3_32_fixed();
+
+  private BucketUtil() {}
+
+  public static int hash(int value) {
+    return MURMUR3.hashLong((long) value).asInt();
+  }
+
+  public static int hash(long value) {
+    return MURMUR3.hashLong(value).asInt();
+  }
+
+  public static int hash(float value) {
+    return MURMUR3.hashLong(Double.doubleToLongBits((double) value)).asInt();
+  }
+
+  public static int hash(double value) {
+    return MURMUR3.hashLong(Double.doubleToLongBits(value)).asInt();
+  }
+
+  public static int hash(CharSequence value) {
+    return MURMUR3.hashString(value, StandardCharsets.UTF_8).asInt();
+  }
+
+  public static int hash(ByteBuffer value) {
+    if (value.hasArray()) {
+      return MURMUR3
+          .hashBytes(value.array(), value.arrayOffset() + value.position(), value.remaining())
+          .asInt();
+    } else {
+      int position = value.position();
+      byte[] copy = new byte[value.remaining()];
+      try {
+        value.get(copy);
+      } finally {
+        // make sure the buffer position is unchanged
+        value.position(position);
+      }
+      return MURMUR3.hashBytes(copy).asInt();
+    }
+  }
+
+  public static int hash(UUID value) {
+    return MURMUR3
+        .newHasher(16)
+        .putLong(Long.reverseBytes(value.getMostSignificantBits()))
+        .putLong(Long.reverseBytes(value.getLeastSignificantBits()))
+        .hash()
+        .asInt();
+  }
+
+  public static int hash(BigDecimal value) {
+    return MURMUR3.hashBytes(value.unscaledValue().toByteArray()).asInt();
+  }
+}
diff --git a/api/src/test/java/org/apache/iceberg/transforms/TestBucketing.java b/api/src/test/java/org/apache/iceberg/transforms/TestBucketing.java
index c5bb8c2b2518..fecf8ca97eca 100644
--- a/api/src/test/java/org/apache/iceberg/transforms/TestBucketing.java
+++ b/api/src/test/java/org/apache/iceberg/transforms/TestBucketing.java
@@ -299,6 +299,25 @@ public void testByteBufferOnHeap() {
     Assert.assertEquals("Buffer limit should not change", 105, buffer.limit());
   }
 
+  @Test
+  public void testByteBufferOnHeapArrayOffset() {
+    byte[] bytes = randomBytes(128);
+    ByteBuffer raw = ByteBuffer.wrap(bytes, 5, 100);
+    ByteBuffer buffer = raw.slice();
+    Assert.assertEquals("Buffer arrayOffset should be 5", 5, buffer.arrayOffset());
+
+    Bucket<ByteBuffer> bucketFunc = Bucket.get(Types.BinaryType.get(), 100);
+
+    Assert.assertEquals(
+        "HeapByteBuffer hash should match hash for correct slice",
+        hashBytes(bytes, 5, 100),
+        bucketFunc.hash(buffer));
+
+    // verify that the buffer was not modified
+    Assert.assertEquals("Buffer position should be 0", 0, buffer.position());
+    Assert.assertEquals("Buffer limit should not change", 100, buffer.limit());
+  }
+
   @Test
   public void testByteBufferOffHeap() {
     byte[] bytes = randomBytes(128);
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
new file mode 100644
index 000000000000..c21d1315841f
--- /dev/null
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.functions;
+
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import java.util.Set;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
+import org.apache.iceberg.util.BucketUtil;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.BinaryType;
+import org.apache.spark.sql.types.ByteType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DateType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.LongType;
+import org.apache.spark.sql.types.ShortType;
+import org.apache.spark.sql.types.StringType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.TimestampType;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A Spark function implementation for the Iceberg bucket transform.
+ *
+ * <p>Example usage: {@code SELECT system.bucket(128, 'abc')}, which returns the bucket 122.
+ *
+ * <p>Note that for performance reasons, the given input number of buckets is not validated in the
+ * implementations used in code-gen. The number of buckets must be positive to give meaningful
+ * results.
+ */
+public class BucketFunction implements UnboundFunction {
+  private static final int NUM_BUCKETS_ORDINAL = 0;
+  private static final int VALUE_ORDINAL = 1;
+  private static final Set<DataType> SUPPORTED_NUM_BUCKETS_TYPES =
+      ImmutableSet.of(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType);
+
+  @Override
+  public BoundFunction bind(StructType inputType) {
+    if (inputType.size() != 2) {
+      throw new UnsupportedOperationException(
+          "Wrong number of inputs (expected numBuckets and value)");
+    }
+
+    StructField numBucketsField = inputType.fields()[NUM_BUCKETS_ORDINAL];
+    StructField valueField = inputType.fields()[VALUE_ORDINAL];
+
+    if (!SUPPORTED_NUM_BUCKETS_TYPES.contains(numBucketsField.dataType())) {
+      throw new UnsupportedOperationException(
+          "Expected number of buckets to be tinyint, shortint or int");
+    }
+
+    DataType type = valueField.dataType();
+    if (type instanceof DateType) {
+      return new BucketInt(type);
+    } else if (type instanceof ByteType
+        || type instanceof ShortType
+        || type instanceof IntegerType) {
+      return new BucketInt(DataTypes.IntegerType);
+    } else if (type instanceof LongType) {
+      return new BucketLong(type);
+    } else if (type instanceof TimestampType) {
+      return new BucketLong(type);
+    } else if (type instanceof DecimalType) {
+      return new BucketDecimal(type);
+    } else if (type instanceof StringType) {
+      return new BucketString();
+    } else if (type instanceof BinaryType) {
+      return new BucketBinary();
+    } else {
+      throw new UnsupportedOperationException(
+          "Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary");
+    }
+  }
+
+  @Override
+  public String description() {
+    return name()
+        + "(numBuckets, col) - Call Iceberg's bucket transform\n"
+        + "  numBuckets :: number of buckets to divide the rows into, e.g. bucket(100, 34) -> 79 (must be a tinyint, smallint, or int)\n"
+        + "  col :: column to bucket (must be a date, integer, long, timestamp, decimal, string, or binary)";
+  }
+
+  @Override
+  public String name() {
+    return "bucket";
+  }
+
+  public abstract static class BucketBase implements ScalarFunction<Integer> {
+    public static int apply(int numBuckets, int hashedValue) {
+      return (hashedValue & Integer.MAX_VALUE) % numBuckets;
+    }
+
+    @Override
+    public String name() {
+      return "bucket";
+    }
+
+    @Override
+    public DataType resultType() {
+      return DataTypes.IntegerType;
+    }
+  }
+
+  // Used for both int and date - tinyint and smallint are upcasted to int by Spark.
+  public static class BucketInt extends BucketBase {
+    private final DataType sqlType;
+
+    // magic method used in codegen
+    public static int invoke(int numBuckets, int value) {
+      return apply(numBuckets, hash(value));
+    }
+
+    // Visible for testing
+    public static int hash(int value) {
+      return BucketUtil.hash(value);
+    }
+
+    public BucketInt(DataType sqlType) {
+      this.sqlType = sqlType;
+    }
+
+    @Override
+    public DataType[] inputTypes() {
+      return new DataType[] {DataTypes.IntegerType, sqlType};
+    }
+
+    @Override
+    public String canonicalName() {
+      return String.format("iceberg.bucket(%s)", sqlType.catalogString());
+    }
+
+    @Override
+    public Integer produceResult(InternalRow input) {
+      // return null for null input to match what Spark does in the code-generated versions.
+      return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
+          ? null
+          : invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getInt(VALUE_ORDINAL));
+    }
+  }
+
+  // Used for both BigInt and Timestamp
+  public static class BucketLong extends BucketBase {
+    private final DataType sqlType;
+
+    // magic function for usage with codegen - needs to be static
+    public static int invoke(int numBuckets, long value) {
+      return apply(numBuckets, hash(value));
+    }
+
+    // Visible for testing
+    public static int hash(long value) {
+      return BucketUtil.hash(value);
+    }
+
+    public BucketLong(DataType sqlType) {
+      this.sqlType = sqlType;
+    }
+
+    @Override
+    public DataType[] inputTypes() {
+      return new DataType[] {DataTypes.IntegerType, sqlType};
+    }
+
+    @Override
+    public String canonicalName() {
+      return String.format("iceberg.bucket(%s)", sqlType.catalogString());
+    }
+
+    @Override
+    public Integer produceResult(InternalRow input) {
+      return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
+          ? null
+          : invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getLong(VALUE_ORDINAL));
+    }
+  }
+
+  public static class BucketString extends BucketBase {
+    // magic function for usage with codegen
+    public static Integer invoke(int numBuckets, UTF8String value) {
+      if (value == null) {
+        return null;
+      }
+
+      // TODO - We can probably hash the bytes directly given they're already UTF-8 input.
+      return apply(numBuckets, hash(value.toString()));
+    }
+
+    // Visible for testing
+    public static int hash(String value) {
+      return BucketUtil.hash(value);
+    }
+
+    @Override
+    public DataType[] inputTypes() {
+      return new DataType[] {DataTypes.IntegerType, DataTypes.StringType};
+    }
+
+    @Override
+    public String canonicalName() {
+      return "iceberg.bucket(string)";
+    }
+
+    @Override
+    public Integer produceResult(InternalRow input) {
+      return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
+          ? null
+          : invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getUTF8String(VALUE_ORDINAL));
+    }
+  }
+
+  public static class BucketBinary extends BucketBase {
+    public static Integer invoke(int numBuckets, byte[] value) {
+      if (value == null) {
+        return null;
+      }
+
+      return apply(numBuckets, hash(ByteBuffer.wrap(value)));
+    }
+
+    // Visible for testing
+    public static int hash(ByteBuffer value) {
+      return BucketUtil.hash(value);
+    }
+
+    @Override
+    public DataType[] inputTypes() {
+      return new DataType[] {DataTypes.IntegerType, DataTypes.BinaryType};
+    }
+
+    @Override
+    public Integer produceResult(InternalRow input) {
+      return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
+          ? null
+          : invoke(input.getInt(NUM_BUCKETS_ORDINAL), input.getBinary(VALUE_ORDINAL));
+    }
+
+    @Override
+    public String canonicalName() {
+      return "iceberg.bucket(binary)";
+    }
+  }
+
+  public static class BucketDecimal extends BucketBase {
+    private final DataType sqlType;
+    private final int precision;
+    private final int scale;
+
+    // magic method used in codegen
+    public static Integer invoke(int numBuckets, Decimal value) {
+      if (value == null) {
+        return null;
+      }
+
+      return apply(numBuckets, hash(value.toJavaBigDecimal()));
+    }
+
+    // Visible for testing
+    public static int hash(BigDecimal value) {
+      return BucketUtil.hash(value);
+    }
+
+    public BucketDecimal(DataType sqlType) {
+      this.sqlType = sqlType;
+      this.precision = ((DecimalType) sqlType).precision();
+      this.scale = ((DecimalType) sqlType).scale();
+    }
+
+    @Override
+    public DataType[] inputTypes() {
+      return new DataType[] {DataTypes.IntegerType, sqlType};
+    }
+
+    @Override
+    public Integer produceResult(InternalRow input) {
+      return input.isNullAt(NUM_BUCKETS_ORDINAL) || input.isNullAt(VALUE_ORDINAL)
+          ? null
+          : invoke(
+              input.getInt(NUM_BUCKETS_ORDINAL),
+              input.getDecimal(VALUE_ORDINAL, precision, scale));
+    }
+
+    @Override
+    public String canonicalName() {
+      return "iceberg.bucket(decimal)";
+    }
+  }
+}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java
index f484a6508bf7..7ca157ad4039 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java
@@ -32,6 +32,7 @@ private SparkFunctions() {}
   private static final Map<String, UnboundFunction> FUNCTIONS =
       ImmutableMap.of(
           "iceberg_version", new IcebergVersionFunction(),
+          "bucket", new BucketFunction(),
           "truncate", new TruncateFunction());
 
   private static final List<String> FUNCTION_NAMES = ImmutableList.copyOf(FUNCTIONS.keySet());
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java
new file mode 100644
index 000000000000..c9c8c02b417c
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.expressions.Literal;
+import org.apache.iceberg.relocated.com.google.common.io.BaseEncoding;
+import org.apache.iceberg.spark.SparkTestBaseWithCatalog;
+import org.apache.iceberg.spark.functions.BucketFunction;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.types.DataTypes;
+import org.assertj.core.api.Assertions;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestSparkBucketFunction extends SparkTestBaseWithCatalog {
+  @Before
+  public void useCatalog() {
+    sql("USE %s", catalogName);
+  }
+
+  @Test
+  public void testSpecValues() {
+    Assert.assertEquals(
+        "Spec example: hash(34) = 2017239379",
+        2017239379,
+        new BucketFunction.BucketInt(DataTypes.IntegerType).hash(34));
+
+    Assert.assertEquals(
+        "Spec example: hash(34L) = 2017239379",
+        2017239379,
+        new BucketFunction.BucketLong(DataTypes.LongType).hash(34L));
+
+    Assert.assertEquals(
+        "Spec example: hash(decimal2(14.20)) = -500754589",
+        -500754589,
+        new BucketFunction.BucketDecimal(DataTypes.createDecimalType(9, 2))
+            .hash(new BigDecimal("14.20")));
+
+    Literal<Integer> date = Literal.of("2017-11-16").to(Types.DateType.get());
+    Assert.assertEquals(
+        "Spec example: hash(2017-11-16) = -653330422",
+        -653330422,
+        new BucketFunction.BucketInt(DataTypes.DateType).hash(date.value()));
+
+    Literal<Long> timestampVal =
+        Literal.of("2017-11-16T22:31:08").to(Types.TimestampType.withoutZone());
+    Assert.assertEquals(
+        "Spec example: hash(2017-11-16T22:31:08) = -2047944441",
+        -2047944441,
+        new BucketFunction.BucketLong(DataTypes.TimestampType).hash(timestampVal.value()));
+
+    Assert.assertEquals(
+        "Spec example: hash(\"iceberg\") = 1210000089",
+        1210000089,
+        new BucketFunction.BucketString().hash("iceberg"));
+
+    ByteBuffer bytes = ByteBuffer.wrap(new byte[] {0, 1, 2, 3});
+    Assert.assertEquals(
+        "Spec example: hash([00 01 02 03]) = -188683207",
+        -188683207,
+        new BucketFunction.BucketBinary().hash(bytes));
+  }
+
+  @Test
+  public void testBucketIntegers() {
+    Assert.assertEquals(
+        "Byte type should bucket similarly to integer",
+        3,
+        scalarSql("SELECT system.bucket(10, 8Y)"));
+    Assert.assertEquals(
+        "Short type should bucket similarly to integer",
+        3,
+        scalarSql("SELECT system.bucket(10, 8S)"));
+    // Integers
+    Assert.assertEquals(3, scalarSql("SELECT system.bucket(10, 8)"));
+    Assert.assertEquals(79, scalarSql("SELECT system.bucket(100, 34)"));
+    Assert.assertNull(scalarSql("SELECT system.bucket(1, CAST(null AS INT))"));
+  }
+
+  @Test
+  public void testBucketDates() {
+    Assert.assertEquals(3, scalarSql("SELECT system.bucket(10, date('1970-01-09'))"));
+    Assert.assertEquals(79, scalarSql("SELECT system.bucket(100, date('1970-02-04'))"));
+    Assert.assertNull(scalarSql("SELECT system.bucket(1, CAST(null AS DATE))"));
+  }
+
+  @Test
+  public void testBucketLong() {
+    Assert.assertEquals(79, scalarSql("SELECT system.bucket(100, 34L)"));
+    Assert.assertEquals(76, scalarSql("SELECT system.bucket(100, 0L)"));
+    Assert.assertEquals(97, scalarSql("SELECT system.bucket(100, -34L)"));
+    Assert.assertEquals(0, scalarSql("SELECT system.bucket(2, -1L)"));
+    Assert.assertNull(scalarSql("SELECT system.bucket(2, CAST(null AS LONG))"));
+  }
+
+  @Test
+  public void testBucketDecimal() {
+    Assert.assertEquals(56, scalarSql("SELECT system.bucket(64, CAST('12.34' as DECIMAL(9, 2)))"));
+    Assert.assertEquals(13, scalarSql("SELECT system.bucket(18, CAST('12.30' as DECIMAL(9, 2)))"));
+    Assert.assertEquals(2, scalarSql("SELECT system.bucket(16, CAST('12.999' as DECIMAL(9, 3)))"));
+    Assert.assertEquals(21, scalarSql("SELECT system.bucket(32, CAST('0.05' as DECIMAL(5, 2)))"));
+    Assert.assertEquals(85, scalarSql("SELECT system.bucket(128, CAST('0.05' as DECIMAL(9, 2)))"));
+    Assert.assertEquals(3, scalarSql("SELECT system.bucket(18, CAST('0.05' as DECIMAL(9, 2)))"));
+
+    Assert.assertNull(
+        "Null input should return null",
+        scalarSql("SELECT system.bucket(2, CAST(null AS decimal))"));
+  }
+
+  @Test
+  public void testBucketTimestamp() {
+    Assert.assertEquals(
+        99, scalarSql("SELECT system.bucket(100, TIMESTAMP '1997-01-01 00:00:00 UTC+00:00')"));
+    Assert.assertEquals(
+        85, scalarSql("SELECT system.bucket(100, TIMESTAMP '1997-01-31 09:26:56 UTC+00:00')"));
+    Assert.assertEquals(
+        62, scalarSql("SELECT system.bucket(100, TIMESTAMP '2022-08-08 00:00:00 UTC+00:00')"));
+    Assert.assertNull(scalarSql("SELECT system.bucket(2, CAST(null AS timestamp))"));
+  }
+
+  @Test
+  public void testBucketString() {
+    Assert.assertEquals(4, scalarSql("SELECT system.bucket(5, 'abcdefg')"));
+    Assert.assertEquals(122, scalarSql("SELECT system.bucket(128, 'abc')"));
+    Assert.assertEquals(54, scalarSql("SELECT system.bucket(64, 'abcde')"));
+    Assert.assertEquals(8, scalarSql("SELECT system.bucket(12, '测试')"));
+    Assert.assertEquals(1, scalarSql("SELECT system.bucket(16, '测试raul试测')"));
+    Assert.assertEquals(
+        "Varchar should work like string",
+        1,
+        scalarSql("SELECT system.bucket(16, CAST('测试raul试测' AS varchar(8)))"));
+    Assert.assertEquals(
+        "Char should work like string",
+        1,
+        scalarSql("SELECT system.bucket(16, CAST('测试raul试测' AS char(8)))"));
+    Assert.assertEquals(
+        "Should not fail on the empty string", 0, scalarSql("SELECT system.bucket(16, '')"));
+    Assert.assertNull(
+        "Null input should return null as output",
+        scalarSql("SELECT system.bucket(16, CAST(null AS string))"));
+  }
+
+  @Test
+  public void testBucketBinary() {
+    Assert.assertEquals(
+        1, scalarSql("SELECT system.bucket(10, X'0102030405060708090a0b0c0d0e0f')"));
+    Assert.assertEquals(10, scalarSql("SELECT system.bucket(12, %s)", asBytesLiteral("abcdefg")));
+    Assert.assertEquals(13, scalarSql("SELECT system.bucket(18, %s)", asBytesLiteral("abc\0\0")));
+    Assert.assertEquals(42, scalarSql("SELECT system.bucket(48, %s)", asBytesLiteral("abc")));
+    Assert.assertEquals(3, scalarSql("SELECT system.bucket(16, %s)", asBytesLiteral("测试_")));
+
+    Assert.assertNull(
+        "Null input should return null as output",
+        scalarSql("SELECT system.bucket(100, CAST(null AS binary))"));
+  }
+
+  @Test
+  public void testNumBucketsAcceptsShortAndByte() {
+    Assert.assertEquals(
+        "Short types should be usable for the number of buckets field",
+        1,
+        scalarSql("SELECT system.bucket(5S, 1L)"));
+
+    Assert.assertEquals(
+        "Byte types should be allowed for the number of buckets field",
+        1,
+        scalarSql("SELECT system.bucket(5Y, 1)"));
+  }
+
+  @Test
+  public void testWrongNumberOfArguments() {
+    AssertHelpers.assertThrows(
+        "Function resolution should not work with zero arguments",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (): Wrong number of inputs (expected numBuckets and value)",
+        () -> scalarSql("SELECT system.bucket()"));
+
+    AssertHelpers.assertThrows(
+        "Function resolution should not work with only one argument",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (int): Wrong number of inputs (expected numBuckets and value)",
+        () -> scalarSql("SELECT system.bucket(1)"));
+
+    AssertHelpers.assertThrows(
+        "Function resolution should not work with more than two arguments",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (int, bigint, int): Wrong number of inputs (expected numBuckets and value)",
+        () -> scalarSql("SELECT system.bucket(1, 1L, 1)"));
+  }
+
+  @Test
+  public void testInvalidTypesCannotBeUsedForNumberOfBuckets() {
+    AssertHelpers.assertThrows(
+        "Decimal type should not be coercible to the number of buckets",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (decimal(9,2), int): Expected number of buckets to be tinyint, shortint or int",
+        () -> scalarSql("SELECT system.bucket(CAST('12.34' as DECIMAL(9, 2)), 10)"));
+
+    AssertHelpers.assertThrows(
+        "Long type should not be coercible to the number of buckets",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (bigint, int): Expected number of buckets to be tinyint, shortint or int",
+        () -> scalarSql("SELECT system.bucket(12L, 10)"));
+
+    AssertHelpers.assertThrows(
+        "String type should not be coercible to the number of buckets",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (string, int): Expected number of buckets to be tinyint, shortint or int",
+        () -> scalarSql("SELECT system.bucket('5', 10)"));
+
+    AssertHelpers.assertThrows(
+        "Interval year to month type should not be coercible to the number of buckets",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (interval year to month, int): Expected number of buckets to be tinyint, shortint or int",
+        () -> scalarSql("SELECT system.bucket(INTERVAL '100-00' YEAR TO MONTH, 10)"));
+
+    AssertHelpers.assertThrows(
+        "Interval day-time type should not be coercible to the number of buckets",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (interval day to second, int): Expected number of buckets to be tinyint, shortint or int",
+        () -> scalarSql("SELECT system.bucket(CAST('11 23:4:0' AS INTERVAL DAY TO SECOND), 10)"));
+  }
+
+  @Test
+  public void testInvalidTypesForBucketColumn() {
+    AssertHelpers.assertThrows(
+        "Float type should not be bucketable",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (int, float): Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary",
+        () -> scalarSql("SELECT system.bucket(10, cast(12.3456 as float))"));
+
+    AssertHelpers.assertThrows(
+        "Double type should not be bucketable",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (int, double): Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary",
+        () -> scalarSql("SELECT system.bucket(10, cast(12.3456 as double))"));
+
+    AssertHelpers.assertThrows(
+        "Boolean type should not be bucketable",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (int, boolean)",
+        () -> scalarSql("SELECT system.bucket(10, true)"));
+
+    AssertHelpers.assertThrows(
+        "Map types should not be bucketable",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (int, map<int,int>)",
+        () -> scalarSql("SELECT system.bucket(10, map(1, 1))"));
+
+    AssertHelpers.assertThrows(
+        "Array types should not be bucketable",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (int, array<bigint>)",
+        () -> scalarSql("SELECT system.bucket(10, array(1L))"));
+
+    AssertHelpers.assertThrows(
+        "Interval year-to-month type should not be bucketable",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (int, interval year to month)",
+        () -> scalarSql("SELECT system.bucket(10, INTERVAL '100-00' YEAR TO MONTH)"));
+
+    AssertHelpers.assertThrows(
+        "Interval day-time type should not be bucketable",
+        AnalysisException.class,
+        "Function 'bucket' cannot process input: (int, interval day to second)",
+        () -> scalarSql("SELECT system.bucket(10, CAST('11 23:4:0' AS INTERVAL DAY TO SECOND))"));
+  }
+
+  @Test
+  public void testThatMagicFunctionsAreInvoked() {
+    // TinyInt
+    Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6Y)"))
+        .asString()
+        .isNotNull()
+        .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt");
+
+    // SmallInt
+    Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6S)"))
+        .asString()
+        .isNotNull()
+        .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt");
+
+    // Int
+    Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6)"))
+        .asString()
+        .isNotNull()
+        .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt");
+
+    // Date
+    Assertions.assertThat(
+            scalarSql("EXPLAIN EXTENDED SELECT system.bucket(100, DATE '2022-08-08')"))
+        .asString()
+        .isNotNull()
+        .contains("staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketInt");
+
+    // Long
+    Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 6L)"))
+        .asString()
+        .isNotNull()
+        .contains(
+            "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketLong");
+
+    // Timestamp
+    Assertions.assertThat(
+            scalarSql("EXPLAIN EXTENDED SELECT system.bucket(100, TIMESTAMP '2022-08-08')"))
+        .asString()
+        .isNotNull()
+        .contains(
+            "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketLong");
+
+    // String
+    Assertions.assertThat(scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, 'abcdefg')"))
+        .asString()
+        .isNotNull()
+        .contains(
+            "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketString");
+
+    // Decimal
+    Assertions.assertThat(
+            scalarSql("EXPLAIN EXTENDED SELECT system.bucket(5, CAST('12.34' AS DECIMAL))"))
+        .asString()
+        .isNotNull()
+        .contains(
+            "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketDecimal");
+
+    // Binary
+    Assertions.assertThat(
+            scalarSql("EXPLAIN EXTENDED SELECT system.bucket(4, X'0102030405060708')"))
+        .asString()
+        .isNotNull()
+        .contains(
+            "staticinvoke(class org.apache.iceberg.spark.functions.BucketFunction$BucketBinary");
+  }
+
+  private String asBytesLiteral(String value) {
+    byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
+    return "X'" + BaseEncoding.base16().encode(bytes) + "'";
+  }
+}
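
Reviewer addendum (illustrative only, not part of the diff): the bound implementations
expose static invoke(...) methods precisely so Spark can call them through StaticInvoke in
generated code, which is what the EXPLAIN assertions above check for. A minimal sketch of
exercising them directly on the JVM, using only classes added by this patch; the demo class
name is hypothetical, and the expected values are the ones already asserted in
TestSparkBucketFunction:

  import org.apache.iceberg.spark.functions.BucketFunction;
  import org.apache.spark.unsafe.types.UTF8String;

  // Hypothetical demo class, not part of this patch.
  public class BucketFunctionDemo {
    public static void main(String[] args) {
      // BucketBase.apply masks the sign bit and then mods:
      // (hash & Integer.MAX_VALUE) % numBuckets
      System.out.println(BucketFunction.BucketInt.invoke(100, 34)); // 79
      System.out.println(BucketFunction.BucketLong.invoke(100, 34L)); // 79
      System.out.println(
          BucketFunction.BucketString.invoke(128, UTF8String.fromString("abc"))); // 122
    }
  }

From SQL, the same results should come back through an Iceberg catalog's system namespace,
e.g. SELECT system.bucket(128, 'abc') after a USE of the Iceberg catalog, which is how the
tests above drive the function.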