Skip to content

Commit 6b004ed

Browse files
authored
Pulsar SQL support for Decimal data type (#15153)
1 parent b083e9a commit 6b004ed

File tree

9 files changed

+96
-4
lines changed

9 files changed

+96
-4
lines changed

pulsar-sql/presto-pulsar/src/main/java/org/apache/pulsar/sql/presto/decoder/avro/PulsarAvroColumnDecoder.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
import io.prestosql.spi.type.BigintType;
4141
import io.prestosql.spi.type.BooleanType;
4242
import io.prestosql.spi.type.DateType;
43+
import io.prestosql.spi.type.DecimalType;
44+
import io.prestosql.spi.type.Decimals;
4345
import io.prestosql.spi.type.DoubleType;
4446
import io.prestosql.spi.type.IntegerType;
4547
import io.prestosql.spi.type.MapType;
@@ -53,6 +55,7 @@
5355
import io.prestosql.spi.type.Type;
5456
import io.prestosql.spi.type.VarbinaryType;
5557
import io.prestosql.spi.type.VarcharType;
58+
import java.math.BigInteger;
5659
import java.nio.ByteBuffer;
5760
import java.util.List;
5861
import java.util.Map;
@@ -139,7 +142,7 @@ private boolean isSupportedType(Type type) {
139142
}
140143

141144
private boolean isSupportedPrimitive(Type type) {
142-
return type instanceof VarcharType || SUPPORTED_PRIMITIVE_TYPES.contains(type);
145+
return type instanceof VarcharType || type instanceof DecimalType || SUPPORTED_PRIMITIVE_TYPES.contains(type);
143146
}
144147

145148
public FieldValueProvider decodeField(GenericRecord avroRecord) {
@@ -205,6 +208,13 @@ public long getLong() {
205208
return floatToIntBits((Float) value);
206209
}
207210

211+
if (columnType instanceof DecimalType) {
212+
ByteBuffer buffer = (ByteBuffer) value;
213+
byte[] bytes = new byte[buffer.remaining()];
214+
buffer.get(bytes);
215+
return new BigInteger(bytes).longValue();
216+
}
217+
208218
throw new PrestoException(DECODER_CONVERSION_NOT_SUPPORTED,
209219
format("cannot decode object of '%s' as '%s' for column '%s'",
210220
value.getClass(), columnType, columnName));
@@ -234,6 +244,13 @@ private static Slice getSlice(Object value, Type type, String columnName) {
234244
}
235245
}
236246

247+
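// The returned Slice must be the 128-bit (16-byte) unscaled encoding that Presto uses for long decimals.
// NOTE(review): buffer.array() returns the whole backing array and ignores position/limit - confirm the buffer is never a slice or read-only view.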
248+
if (type instanceof DecimalType) {
249+
ByteBuffer buffer = (ByteBuffer) value;
250+
BigInteger bigInteger = new BigInteger(buffer.array());
251+
return Decimals.encodeUnscaledValue(bigInteger);
252+
}
253+
237254
throw new PrestoException(DECODER_CONVERSION_NOT_SUPPORTED,
238255
format("cannot decode object of '%s' as '%s' for column '%s'",
239256
value.getClass(), type, columnName));

pulsar-sql/presto-pulsar/src/main/java/org/apache/pulsar/sql/presto/decoder/avro/PulsarAvroRowDecoderFactory.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import io.prestosql.spi.type.ArrayType;
3434
import io.prestosql.spi.type.BigintType;
3535
import io.prestosql.spi.type.BooleanType;
36+
import io.prestosql.spi.type.DecimalType;
3637
import io.prestosql.spi.type.DoubleType;
3738
import io.prestosql.spi.type.IntegerType;
3839
import io.prestosql.spi.type.RealType;
@@ -128,7 +129,14 @@ private Type parseAvroPrestoType(String fieldname, Schema schema) {
128129
+ "please check the schema or report the bug.", fieldname));
129130
case FIXED:
130131
case BYTES:
131-
//TODO: support decimal logicalType
132+
// Precision <= 0: an exception is thrown.
133+
// Precision in 1..18: a short decimal is used, carried as a long.
134+
// Precision in 19..38: a long decimal is used, carried as a Slice.
135+
// Precision > 38: an exception is thrown.
136+
if (logicalType instanceof LogicalTypes.Decimal) {
137+
LogicalTypes.Decimal decimal = (LogicalTypes.Decimal) logicalType;
138+
return DecimalType.createDecimalType(decimal.getPrecision(), decimal.getScale());
139+
}
132140
return VarbinaryType.VARBINARY;
133141
case INT:
134142
if (logicalType == LogicalTypes.timeMillis()) {

pulsar-sql/presto-pulsar/src/main/java/org/apache/pulsar/sql/presto/decoder/json/PulsarJsonRowDecoderFactory.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,12 @@ private Type parseJsonPrestoType(String fieldname, Schema schema) {
128128
+ "please check the schema or report the bug.", fieldname));
129129
case FIXED:
130130
case BYTES:
131+
// In the current implementation the JSON schema is generated by Avro,
132+
// so a bytes/fixed field may carry LogicalTypes.Decimal.
133+
// Such decimal fields are mapped to an unbounded VARCHAR for JSON schemas.
134+
if (logicalType instanceof LogicalTypes.Decimal) {
135+
return createUnboundedVarcharType();
136+
}
131137
return VarbinaryType.VARBINARY;
132138
case INT:
133139
if (logicalType == LogicalTypes.timeMillis()) {

pulsar-sql/presto-pulsar/src/test/java/org/apache/pulsar/sql/presto/TestPulsarConnector.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import io.prestosql.spi.connector.ConnectorContext;
2626
import io.prestosql.spi.predicate.TupleDomain;
2727
import io.prestosql.testing.TestingConnectorContext;
28+
import java.math.BigDecimal;
2829
import org.apache.bookkeeper.mledger.AsyncCallbacks;
2930
import org.apache.bookkeeper.mledger.Entry;
3031
import org.apache.bookkeeper.mledger.ManagedLedgerConfig;
@@ -166,6 +167,8 @@ public enum TestEnum {
166167
public int time;
167168
@org.apache.avro.reflect.AvroSchema("{ \"type\": \"int\", \"logicalType\": \"date\" }")
168169
public int date;
170+
@org.apache.avro.reflect.AvroSchema("{ \"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 4, \"scale\": 2 }")
171+
public BigDecimal decimal;
169172
public TestPulsarConnector.Bar bar;
170173
public TestEnum field7;
171174
}
@@ -253,6 +256,7 @@ public static class Bar {
253256
fooFieldNames.add("date");
254257
fooFieldNames.add("bar");
255258
fooFieldNames.add("field7");
259+
fooFieldNames.add("decimal");
256260

257261

258262
ConnectorContext prestoConnectorContext = new TestingConnectorContext();
@@ -313,6 +317,7 @@ public static class Bar {
313317
LocalDate epoch = LocalDate.ofEpochDay(0);
314318
return Math.toIntExact(ChronoUnit.DAYS.between(epoch, localDate));
315319
});
320+
fooFunctions.put("decimal", integer -> BigDecimal.valueOf(1234, 2));
316321
fooFunctions.put("bar.field1", integer -> integer % 3 == 0 ? null : integer + 1);
317322
fooFunctions.put("bar.field2", integer -> integer % 2 == 0 ? null : String.valueOf(integer + 2));
318323
fooFunctions.put("bar.field3", integer -> integer + 3.0f);
@@ -331,7 +336,6 @@ public static class Bar {
331336
* @param schemaInfo
332337
* @param handleKeyValueType
333338
* @param includeInternalColumn
334-
* @param dispatchingRowDecoderFactory
335339
* @return
336340
*/
337341
protected static List<PulsarColumnHandle> getColumnColumnHandles(TopicName topicName, SchemaInfo schemaInfo,
@@ -393,6 +397,7 @@ private static List<Entry> getTopicEntries(String topicSchemaName) {
393397
LocalDate localDate = LocalDate.now();
394398
LocalDate epoch = LocalDate.ofEpochDay(0);
395399
foo.date = Math.toIntExact(ChronoUnit.DAYS.between(epoch, localDate));
400+
foo.decimal= BigDecimal.valueOf(count, 2);
396401

397402
MessageMetadata messageMetadata = new MessageMetadata()
398403
.setProducerName("test-producer").setSequenceId(i)
@@ -609,6 +614,7 @@ public void run() {
609614
foo.timestamp = (long) fooFunctions.get("timestamp").apply(count);
610615
foo.time = (int) fooFunctions.get("time").apply(count);
611616
foo.date = (int) fooFunctions.get("date").apply(count);
617+
foo.decimal = (BigDecimal) fooFunctions.get("decimal").apply(count);
612618
foo.bar = bar;
613619
foo.field7 = (Foo.TestEnum) fooFunctions.get("field7").apply(count);
614620

pulsar-sql/presto-pulsar/src/test/java/org/apache/pulsar/sql/presto/TestPulsarRecordCursor.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
import io.airlift.log.Logger;
2323
import io.netty.buffer.ByteBuf;
2424
import io.prestosql.spi.predicate.TupleDomain;
25+
import io.prestosql.spi.type.DecimalType;
2526
import io.prestosql.spi.type.RowType;
27+
import io.prestosql.spi.type.Type;
28+
import io.prestosql.spi.type.VarcharType;
29+
import java.math.BigDecimal;
2630
import lombok.Data;
2731
import org.apache.bookkeeper.mledger.AsyncCallbacks;
2832
import org.apache.bookkeeper.mledger.Entry;
@@ -142,6 +146,17 @@ public void testTopics() throws Exception {
142146
}else if (fooColumnHandles.get(i).getName().equals("field7")) {
143147
assertEquals(pulsarRecordCursor.getSlice(i).getBytes(), fooFunctions.get("field7").apply(count).toString().getBytes());
144148
columnsSeen.add(fooColumnHandles.get(i).getName());
149+
}else if (fooColumnHandles.get(i).getName().equals("decimal")) {
150+
Type type = fooColumnHandles.get(i).getType();
151+
// The JSON decoder maps decimal columns to VarcharType.
152+
if (type instanceof VarcharType) {
153+
assertEquals(new String(pulsarRecordCursor.getSlice(i).getBytes()),
154+
fooFunctions.get("decimal").apply(count).toString());
155+
} else {
156+
DecimalType decimalType = (DecimalType) fooColumnHandles.get(i).getType();
157+
assertEquals(BigDecimal.valueOf(pulsarRecordCursor.getLong(i), decimalType.getScale()), fooFunctions.get("decimal").apply(count));
158+
}
159+
columnsSeen.add(fooColumnHandles.get(i).getName());
145160
} else {
146161
if (PulsarInternalColumn.getInternalFieldsMap().containsKey(fooColumnHandles.get(i).getName())) {
147162
columnsSeen.add(fooColumnHandles.get(i).getName());

pulsar-sql/presto-pulsar/src/test/java/org/apache/pulsar/sql/presto/decoder/AbstractDecoderTester.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import io.prestosql.spi.connector.ConnectorContext;
2727
import io.prestosql.spi.type.Type;
2828
import io.prestosql.testing.TestingConnectorContext;
29+
import java.math.BigDecimal;
2930
import org.apache.pulsar.common.naming.NamespaceName;
3031
import org.apache.pulsar.common.naming.TopicName;
3132
import org.apache.pulsar.common.schema.SchemaInfo;
@@ -102,6 +103,10 @@ protected void checkValue(Map<DecoderColumnHandle, FieldValueProvider> decodedRo
102103
decoderTestUtil.checkValue(decodedRow, handle, value);
103104
}
104105

106+
protected void checkValue(Map<DecoderColumnHandle, FieldValueProvider> decodedRow, DecoderColumnHandle handle, BigDecimal value) {
107+
decoderTestUtil.checkValue(decodedRow, handle, value);
108+
}
109+
105110
protected Block getBlock(Map<DecoderColumnHandle, FieldValueProvider> decodedRow, DecoderColumnHandle handle) {
106111
FieldValueProvider provider = decodedRow.get(handle);
107112
assertNotNull(provider);

pulsar-sql/presto-pulsar/src/test/java/org/apache/pulsar/sql/presto/decoder/DecoderTestMessage.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
*/
1919
package org.apache.pulsar.sql.presto.decoder;
2020

21+
import java.math.BigDecimal;
2122
import lombok.Data;
2223

2324
import java.util.List;
@@ -45,6 +46,10 @@ public static enum TestEnum {
4546
public int dateField;
4647
public TestRow rowField;
4748
public TestEnum enumField;
49+
@org.apache.avro.reflect.AvroSchema("{ \"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 4, \"scale\": 2 }")
50+
public BigDecimal decimalField;
51+
@org.apache.avro.reflect.AvroSchema("{ \"type\": \"bytes\", \"logicalType\": \"decimal\", \"precision\": 30, \"scale\": 2 }")
52+
public BigDecimal longDecimalField;
4853

4954
public List<String> arrayField;
5055
public Map<String, Long> mapField;
@@ -62,7 +67,6 @@ public static class NestedRow {
6267
public long longField;
6368
}
6469

65-
6670
public static class CompositeRow {
6771
public String stringField;
6872
public List<NestedRow> arrayField;

pulsar-sql/presto-pulsar/src/test/java/org/apache/pulsar/sql/presto/decoder/DecoderTestUtil.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,16 @@
2323
import io.prestosql.decoder.FieldValueProvider;
2424
import io.prestosql.spi.block.Block;
2525
import io.prestosql.spi.type.ArrayType;
26+
import io.prestosql.spi.type.DecimalType;
27+
import io.prestosql.spi.type.Decimals;
2628
import io.prestosql.spi.type.MapType;
2729
import io.prestosql.spi.type.RowType;
2830
import io.prestosql.spi.type.Type;
31+
import java.math.BigDecimal;
32+
import java.math.BigInteger;
2933
import java.util.Map;
3034

35+
import static io.prestosql.spi.type.UnscaledDecimal128Arithmetic.UNSCALED_DECIMAL_128_SLICE_LENGTH;
3136
import static io.prestosql.testing.TestingConnectorSession.SESSION;
3237
import static org.testng.Assert.*;
3338

@@ -113,6 +118,21 @@ public void checkValue(Map<DecoderColumnHandle, FieldValueProvider> decodedRow,
113118
assertEquals(provider.getBoolean(), value);
114119
}
115120

121+
public void checkValue(Map<DecoderColumnHandle, FieldValueProvider> decodedRow, DecoderColumnHandle handle, BigDecimal value) {
122+
FieldValueProvider provider = decodedRow.get(handle);
123+
DecimalType decimalType = (DecimalType) handle.getType();
124+
BigDecimal actualDecimal;
125+
if (decimalType.getFixedSize() == UNSCALED_DECIMAL_128_SLICE_LENGTH) {
126+
Slice slice = provider.getSlice();
127+
BigInteger bigInteger = Decimals.decodeUnscaledValue(slice);
128+
actualDecimal = new BigDecimal(bigInteger, decimalType.getScale());
129+
} else {
130+
actualDecimal = BigDecimal.valueOf(provider.getLong(), decimalType.getScale());
131+
}
132+
assertNotNull(provider);
133+
assertEquals(actualDecimal, value);
134+
}
135+
116136
public void checkIsNull(Map<DecoderColumnHandle, FieldValueProvider> decodedRow, DecoderColumnHandle handle) {
117137
FieldValueProvider provider = decodedRow.get(handle);
118138
assertNotNull(provider);

pulsar-sql/presto-pulsar/src/test/java/org/apache/pulsar/sql/presto/decoder/avro/TestAvroDecoder.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@
2525
import io.prestosql.spi.PrestoException;
2626
import io.prestosql.spi.type.ArrayType;
2727
import io.prestosql.spi.type.BigintType;
28+
import io.prestosql.spi.type.DecimalType;
2829
import io.prestosql.spi.type.RowType;
2930
import io.prestosql.spi.type.StandardTypes;
3031
import io.prestosql.spi.type.Type;
3132
import io.prestosql.spi.type.TypeSignatureParameter;
3233
import io.prestosql.spi.type.VarcharType;
34+
import java.math.BigDecimal;
3335
import java.util.Arrays;
3436
import java.util.HashMap;
3537
import java.util.HashSet;
@@ -87,6 +89,8 @@ public void testPrimitiveType() {
8789
message.longField = 222L;
8890
message.timestampField = System.currentTimeMillis();
8991
message.enumField = DecoderTestMessage.TestEnum.TEST_ENUM_1;
92+
message.decimalField = BigDecimal.valueOf(2233, 2);
93+
message.longDecimalField = new BigDecimal("1234567891234567891234567891.23");
9094

9195
LocalTime now = LocalTime.now(ZoneId.systemDefault());
9296
message.timeField = now.toSecondOfDay() * 1000;
@@ -127,6 +131,13 @@ public void testPrimitiveType() {
127131
"enumField", VARCHAR, false, false, "enumField", null, null, PulsarColumnHandle.HandleKeyValueType.NONE);
128132
checkValue(decodedRow, enumFieldColumnHandle, message.enumField.toString());
129133

134+
PulsarColumnHandle decimalFieldColumnHandle = new PulsarColumnHandle(getPulsarConnectorId().toString(),
135+
"decimalField", DecimalType.createDecimalType(4, 2), false, false, "decimalField", null, null, PulsarColumnHandle.HandleKeyValueType.NONE);
136+
checkValue(decodedRow, decimalFieldColumnHandle, message.decimalField);
137+
138+
PulsarColumnHandle longDecimalFieldColumnHandle = new PulsarColumnHandle(getPulsarConnectorId().toString(),
139+
"longDecimalField", DecimalType.createDecimalType(30, 2), false, false, "longDecimalField", null, null, PulsarColumnHandle.HandleKeyValueType.NONE);
140+
checkValue(decodedRow, longDecimalFieldColumnHandle, message.longDecimalField);
130141
}
131142

132143
@Test

0 commit comments

Comments
 (0)