Skip to content

Commit 1ed4023

Browse files
committed
Improve type safety for TypeParameter
Use a sum type instead of a tag + value. It makes it possible to use deconstruction patterns and switch statements/expressions to manipulate the values of the type.
1 parent 43ffc7d commit 1ed4023

File tree

29 files changed

+272
-294
lines changed

29 files changed

+272
-294
lines changed

core/trino-main/src/main/java/io/trino/metadata/SignatureBinder.java

Lines changed: 80 additions & 103 deletions
Large diffs are not rendered by default.

core/trino-main/src/main/java/io/trino/operator/aggregation/TypeSignatureMapping.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import io.trino.operator.annotations.LiteralImplementationDependency;
2222
import io.trino.operator.annotations.OperatorImplementationDependency;
2323
import io.trino.operator.annotations.TypeImplementationDependency;
24-
import io.trino.spi.type.ParameterKind;
2524
import io.trino.spi.type.TypeParameter;
2625
import io.trino.spi.type.TypeSignature;
2726

@@ -104,8 +103,8 @@ public TypeSignature mapTypeSignature(TypeSignature typeSignature)
104103

105104
private TypeParameter mapTypeSignatureParameter(TypeParameter parameter)
106105
{
107-
if (parameter.getKind() == ParameterKind.TYPE) {
108-
return TypeParameter.typeParameter(parameter.name(), mapTypeSignature(parameter.getTypeSignature()));
106+
if (parameter.value() instanceof TypeParameter.Type(TypeSignature type)) {
107+
return TypeParameter.typeParameter(parameter.name(), mapTypeSignature(type));
109108
}
110109
return parameter;
111110
}

core/trino-main/src/main/java/io/trino/operator/annotations/FunctionsParserHelper.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,8 @@ private static void verifyTypeSignatureDoesNotContainAnyTypeParameters(TypeSigna
176176
checkArgument(!typeParameterNames.contains(typeSignature.getBase()), "Nested type variables are not allowed: %s", rootType);
177177

178178
for (TypeParameter parameter : typeSignature.getParameters()) {
179-
switch (parameter.getKind()) {
180-
case TYPE:
181-
verifyTypeSignatureDoesNotContainAnyTypeParameters(rootType, parameter.getTypeSignature(), typeParameterNames);
182-
break;
183-
case LONG:
184-
case VARIABLE:
185-
break;
186-
default:
187-
throw new UnsupportedOperationException();
179+
if (parameter.value() instanceof TypeParameter.Type(TypeSignature type)) {
180+
verifyTypeSignatureDoesNotContainAnyTypeParameters(rootType, type, typeParameterNames);
188181
}
189182
}
190183
}

core/trino-main/src/main/java/io/trino/operator/annotations/ImplementationDependency.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ static void checkTypeParameters(TypeSignature typeSignature, Set<String> typePar
8787
}
8888

8989
for (TypeParameter parameter : typeSignature.getParameters()) {
90-
if (parameter.isTypeSignature()) {
91-
checkTypeParameters(parameter.getTypeSignature(), typeParameterNames, element);
90+
if (parameter.value() instanceof TypeParameter.Type(TypeSignature type)) {
91+
checkTypeParameters(type, typeParameterNames, element);
9292
}
9393
}
9494
}

core/trino-main/src/main/java/io/trino/operator/scalar/Re2JCastToRegexpFunction.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.trino.spi.function.FunctionMetadata;
2424
import io.trino.spi.function.Signature;
2525
import io.trino.spi.type.Type;
26+
import io.trino.spi.type.TypeParameter;
2627
import io.trino.type.Re2JRegexp;
2728

2829
import java.lang.invoke.MethodHandle;
@@ -73,7 +74,7 @@ private Re2JCastToRegexpFunction(String sourceType, int dfaStatesLimit, int dfaR
7374
protected SpecializedSqlScalarFunction specialize(BoundSignature boundSignature)
7475
{
7576
Type inputType = boundSignature.getArgumentType(0);
76-
Long typeLength = inputType.getTypeSignature().getParameters().get(0).getLongLiteral();
77+
Long typeLength = ((TypeParameter.Numeric) inputType.getTypeSignature().getParameters().get(0).value()).value();
7778
return new ChoicesSpecializedSqlScalarFunction(
7879
boundSignature,
7980
FAIL_ON_NULL,

core/trino-main/src/main/java/io/trino/server/protocol/ProtocolUtil.java

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,20 +149,18 @@ private static ClientTypeSignature toClientTypeSignature(TypeSignature signature
149149

150150
private static ClientTypeSignatureParameter toClientTypeSignatureParameter(TypeParameter parameter, boolean supportsParametricDateTime)
151151
{
152-
switch (parameter.getKind()) {
153-
case TYPE:
152+
return switch (parameter.value()) {
153+
case TypeParameter.Type(TypeSignature type) -> {
154154
if (parameter.name().isPresent()) {
155-
return ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(
155+
yield ClientTypeSignatureParameter.ofNamedType(new NamedClientTypeSignature(
156156
parameter.name().map(RowFieldName::new),
157-
toClientTypeSignature(parameter.getTypeSignature(), supportsParametricDateTime)));
157+
toClientTypeSignature(type, supportsParametricDateTime)));
158158
}
159-
return ClientTypeSignatureParameter.ofType(toClientTypeSignature(parameter.getTypeSignature(), supportsParametricDateTime));
160-
case LONG:
161-
return ClientTypeSignatureParameter.ofLong(parameter.getLongLiteral());
162-
case VARIABLE:
163-
// not expected here
164-
}
165-
throw new IllegalArgumentException("Unsupported kind: " + parameter.getKind());
159+
yield ClientTypeSignatureParameter.ofType(toClientTypeSignature(type, supportsParametricDateTime));
160+
}
161+
case TypeParameter.Numeric number -> ClientTypeSignatureParameter.ofLong(number.value());
162+
case TypeParameter.Variable _ -> throw new IllegalArgumentException("Unsupported parameter kind: " + parameter);
163+
};
166164
}
167165

168166
public static StatementStats toStatementStats(ResultQueryInfo queryInfo)

core/trino-main/src/main/java/io/trino/sql/analyzer/TypeSignatureTranslator.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,12 @@ static DataType toDataType(TypeSignature typeSignature)
265265
.map(parameter -> new RowDataType.Field(
266266
Optional.empty(),
267267
parameter.name().map(fieldName -> new Identifier(fieldName, requiresDelimiting(fieldName))),
268-
toDataType(parameter.getTypeSignature()))).collect(toImmutableList()));
268+
toDataType(((TypeParameter.Type) parameter.value()).type()))).collect(toImmutableList()));
269269
case VARCHAR -> new GenericDataType(
270270
Optional.empty(),
271271
new Identifier(typeSignature.getBase(), false),
272272
typeSignature.getParameters().stream()
273-
.filter(parameter -> parameter.getLongLiteral() != UNBOUNDED_LENGTH)
273+
.filter(parameter -> ((TypeParameter.Numeric) parameter.value()).value() != UNBOUNDED_LENGTH)
274274
.map(parameter -> new NumericParameter(Optional.empty(), parameter.toString()))
275275
.collect(toImmutableList()));
276276
default -> new GenericDataType(
@@ -304,9 +304,9 @@ private static boolean isValidIdentifier(String identifier)
304304

305305
private static DataTypeParameter toTypeParameter(TypeParameter parameter)
306306
{
307-
return switch (parameter.getKind()) {
308-
case LONG -> new NumericParameter(Optional.empty(), parameter.toString());
309-
case TYPE -> new io.trino.sql.tree.TypeParameter(toDataType(parameter.getTypeSignature()));
307+
return switch (parameter.value()) {
308+
case TypeParameter.Numeric numeric -> new NumericParameter(Optional.empty(), Long.toString(numeric.value()));
309+
case TypeParameter.Type type -> new io.trino.sql.tree.TypeParameter(toDataType(type.type()));
310310
default -> throw new UnsupportedOperationException("Unsupported parameter kind");
311311
};
312312
}

core/trino-main/src/main/java/io/trino/type/ArrayParametricType.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
package io.trino.type;
1515

1616
import io.trino.spi.type.ArrayType;
17-
import io.trino.spi.type.ParameterKind;
1817
import io.trino.spi.type.ParametricType;
1918
import io.trino.spi.type.StandardTypes;
2019
import io.trino.spi.type.Type;
2120
import io.trino.spi.type.TypeManager;
2221
import io.trino.spi.type.TypeParameter;
22+
import io.trino.spi.type.TypeSignature;
2323

2424
import java.util.List;
2525

@@ -42,10 +42,10 @@ public String getName()
4242
public Type createType(TypeManager typeManager, List<TypeParameter> parameters)
4343
{
4444
checkArgument(parameters.size() == 1, "Array type expects exactly one type as a parameter, got %s", parameters);
45-
checkArgument(
46-
parameters.get(0).getKind() == ParameterKind.TYPE,
47-
"Array expects type as a parameter, got %s",
48-
parameters);
49-
return new ArrayType(typeManager.getType(parameters.get(0).getTypeSignature()));
45+
46+
if (parameters.get(0).value() instanceof TypeParameter.Type(TypeSignature type)) {
47+
return new ArrayType(typeManager.getType(type));
48+
}
49+
throw new IllegalArgumentException("Array expects type as a parameter, got " + parameters);
5050
}
5151
}

core/trino-main/src/main/java/io/trino/type/CharParametricType.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,14 @@ public Type createType(TypeManager typeManager, List<TypeParameter> parameters)
4848

4949
TypeParameter parameter = parameters.get(0);
5050

51-
if (!parameter.isLongLiteral()) {
51+
if (!(parameter.value() instanceof TypeParameter.Numeric(long value))) {
5252
throw new IllegalArgumentException("CHAR length must be a number");
5353
}
5454

55-
long length = parameter.getLongLiteral();
56-
if (length < 0 || length > CharType.MAX_LENGTH) {
57-
throw new IllegalArgumentException("Invalid CHAR length " + length);
55+
if (value < 0 || value > CharType.MAX_LENGTH) {
56+
throw new IllegalArgumentException("Invalid CHAR length " + value);
5857
}
5958

60-
return createCharType(toIntExact(length));
59+
return createCharType(toIntExact(value));
6160
}
6261
}

core/trino-main/src/main/java/io/trino/type/DecimalParametricType.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ public Type createType(TypeManager typeManager, List<TypeParameter> parameters)
3838
{
3939
return switch (parameters.size()) {
4040
case 0 -> DecimalType.createDecimalType();
41-
case 1 -> DecimalType.createDecimalType(parameters.get(0).getLongLiteral().intValue());
42-
case 2 -> DecimalType.createDecimalType(parameters.get(0).getLongLiteral().intValue(), parameters.get(1).getLongLiteral().intValue());
41+
case 1 -> DecimalType.createDecimalType((int) ((TypeParameter.Numeric) parameters.get(0).value()).value());
42+
case 2 -> DecimalType.createDecimalType(
43+
(int) ((TypeParameter.Numeric) parameters.get(0).value()).value(),
44+
(int) ((TypeParameter.Numeric) parameters.get(1).value()).value());
4345
default -> throw new IllegalArgumentException("Expected 0, 1 or 2 parameters for DECIMAL type constructor.");
4446
};
4547
}

0 commit comments

Comments
 (0)