Skip to content

Commit 20acc62

Browse files
committed
HHH-17357 Add hibernate-types module with pgvector support
1 parent c700dcd commit 20acc62

20 files changed

+917
-54
lines changed

docker_db.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,22 +147,26 @@ postgresql() {
147147
postgresql_12() {
148148
$CONTAINER_CLI rm -f postgres || true
149149
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:12-3.4
150+
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-12-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
150151
}
151152

152153
postgresql_13() {
153154
$CONTAINER_CLI rm -f postgres || true
154155
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:13-3.1
156+
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-13-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
155157
}
156158

157159
postgresql_14() {
158160
$CONTAINER_CLI rm -f postgres || true
159161
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:14-3.3
162+
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-14-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
160163
}
161164

162165
postgresql_15() {
163166
$CONTAINER_CLI rm -f postgres || true
164167
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 --tmpfs /pgtmpfs:size=131072k -d docker.io/postgis/postgis:15-3.3 \
165168
-c fsync=off -c synchronous_commit=off -c full_page_writes=off -c shared_buffers=256MB -c maintenance_work_mem=256MB -c max_wal_size=1GB -c checkpoint_timeout=1d
169+
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-15-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
166170
}
167171

168172
edb() {

hibernate-core/src/main/java/org/hibernate/dialect/function/AvgFunction.java

Lines changed: 130 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,23 @@
88

99
import java.util.Arrays;
1010
import java.util.List;
11+
import java.util.Locale;
12+
import java.util.function.Supplier;
1113

1214
import org.hibernate.dialect.Dialect;
15+
import org.hibernate.metamodel.mapping.BasicValuedMapping;
1316
import org.hibernate.metamodel.mapping.JdbcMapping;
17+
import org.hibernate.metamodel.mapping.MappingModelExpressible;
18+
import org.hibernate.metamodel.model.domain.DomainType;
1419
import org.hibernate.query.ReturnableType;
20+
import org.hibernate.query.sqm.SqmExpressible;
1521
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
1622
import org.hibernate.query.sqm.function.FunctionKind;
17-
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
18-
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
23+
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
24+
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
25+
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
1926
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
20-
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
27+
import org.hibernate.query.sqm.tree.SqmTypedNode;
2128
import org.hibernate.sql.ast.Clause;
2229
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
2330
import org.hibernate.sql.ast.SqlAstTranslator;
@@ -27,8 +34,14 @@
2734
import org.hibernate.sql.ast.tree.expression.Distinct;
2835
import org.hibernate.sql.ast.tree.expression.Expression;
2936
import org.hibernate.sql.ast.tree.predicate.Predicate;
37+
import org.hibernate.type.BasicPluralType;
3038
import org.hibernate.type.BasicType;
39+
import org.hibernate.type.SqlTypes;
3140
import org.hibernate.type.StandardBasicTypes;
41+
import org.hibernate.type.descriptor.java.JavaType;
42+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
43+
import org.hibernate.type.descriptor.jdbc.JdbcType;
44+
import org.hibernate.type.descriptor.jdbc.ObjectJdbcType;
3245
import org.hibernate.type.spi.TypeConfiguration;
3346

3447
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC;
@@ -49,10 +62,8 @@ public AvgFunction(
4962
super(
5063
"avg",
5164
FunctionKind.AGGREGATE,
52-
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ),
53-
StandardFunctionReturnTypeResolvers.invariant(
54-
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE )
55-
),
65+
new Validator(),
66+
new ReturnTypeResolver( typeConfiguration ),
5667
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, NUMERIC )
5768
);
5869
this.defaultArgumentRenderingMode = defaultArgumentRenderingMode;
@@ -131,4 +142,116 @@ public String getArgumentListSignature() {
131142
return "(NUMERIC arg)";
132143
}
133144

145+
public static class Validator implements ArgumentsValidator {
146+
147+
public static final ArgumentsValidator INSTANCE = new Validator();
148+
149+
@Override
150+
public void validate(
151+
List<? extends SqmTypedNode<?>> arguments,
152+
String functionName,
153+
TypeConfiguration typeConfiguration) {
154+
if ( arguments.size() != 1 ) {
155+
throw new FunctionArgumentException(
156+
String.format(
157+
Locale.ROOT,
158+
"Function %s() has %d parameters, but %d arguments given",
159+
functionName,
160+
1,
161+
arguments.size()
162+
)
163+
);
164+
}
165+
final SqmTypedNode<?> argument = arguments.get( 0 );
166+
final SqmExpressible<?> expressible = argument.getExpressible();
167+
final DomainType<?> domainType;
168+
if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
169+
final JdbcType jdbcType = getJdbcType( domainType, typeConfiguration );
170+
if ( !isNumeric( jdbcType ) ) {
171+
throw new FunctionArgumentException(
172+
String.format(
173+
"Parameter %d of function '%s()' has type '%s', but argument is of type '%s'",
174+
1,
175+
functionName,
176+
NUMERIC,
177+
domainType.getTypeName()
178+
)
179+
);
180+
}
181+
}
182+
}
183+
184+
private static boolean isNumeric(JdbcType jdbcType) {
185+
final int sqlTypeCode = jdbcType.getDefaultSqlTypeCode();
186+
if ( SqlTypes.isNumericType( sqlTypeCode ) ) {
187+
return true;
188+
}
189+
if ( jdbcType instanceof ArrayJdbcType ) {
190+
return isNumeric( ( (ArrayJdbcType) jdbcType ).getElementJdbcType() );
191+
}
192+
return false;
193+
}
194+
195+
private static JdbcType getJdbcType(DomainType<?> domainType, TypeConfiguration typeConfiguration) {
196+
if ( domainType instanceof JdbcMapping ) {
197+
return ( (JdbcMapping) domainType ).getJdbcType();
198+
}
199+
else {
200+
final JavaType<?> javaType = domainType.getExpressibleJavaType();
201+
if ( javaType.getJavaTypeClass().isEnum() ) {
202+
// we can't tell if the enum is mapped STRING or ORDINAL
203+
return ObjectJdbcType.INSTANCE;
204+
}
205+
else {
206+
return javaType.getRecommendedJdbcType( typeConfiguration.getCurrentBaseSqlTypeIndicators() );
207+
}
208+
}
209+
}
210+
211+
@Override
212+
public String getSignature() {
213+
return "(arg)";
214+
}
215+
}
216+
217+
public static class ReturnTypeResolver implements FunctionReturnTypeResolver {
218+
219+
private final BasicType<Double> doubleType;
220+
221+
public ReturnTypeResolver(TypeConfiguration typeConfiguration) {
222+
this.doubleType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
223+
}
224+
225+
@Override
226+
public BasicValuedMapping resolveFunctionReturnType(
227+
Supplier<BasicValuedMapping> impliedTypeAccess,
228+
List<? extends SqlAstNode> arguments) {
229+
final BasicValuedMapping impliedType = impliedTypeAccess.get();
230+
if ( impliedType != null ) {
231+
return impliedType;
232+
}
233+
final JdbcMapping jdbcMapping = ( (Expression) arguments.get( 0 ) ).getExpressionType().getSingleJdbcMapping();
234+
if ( jdbcMapping instanceof BasicPluralType<?, ?> ) {
235+
return (BasicValuedMapping) jdbcMapping;
236+
}
237+
return doubleType;
238+
}
239+
240+
@Override
241+
public ReturnableType<?> resolveFunctionReturnType(
242+
ReturnableType<?> impliedType,
243+
Supplier<MappingModelExpressible<?>> inferredTypeSupplier,
244+
List<? extends SqmTypedNode<?>> arguments,
245+
TypeConfiguration typeConfiguration) {
246+
final SqmExpressible<?> expressible = arguments.get( 0 ).getExpressible();
247+
final DomainType<?> domainType;
248+
if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
249+
if ( domainType instanceof BasicPluralType<?, ?> ) {
250+
return (ReturnableType<?>) domainType;
251+
}
252+
}
253+
return typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
254+
}
255+
}
256+
134257
}

hibernate-core/src/main/java/org/hibernate/dialect/function/CommonFunctionFactory.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2009,11 +2009,11 @@ public void aggregates(Dialect dialect, SqlAstNodeRenderingMode inferenceArgumen
20092009
.setExactArgumentCount( 1 )
20102010
.register();
20112011

2012+
20122013
functionRegistry.namedAggregateDescriptorBuilder( "avg" )
20132014
.setArgumentRenderingMode( inferenceArgumentRenderingMode )
2014-
.setInvariantType(doubleType)
2015-
.setExactArgumentCount( 1 )
2016-
.setParameterTypes(NUMERIC)
2015+
.setArgumentsValidator( AvgFunction.Validator.INSTANCE )
2016+
.setReturnTypeResolver( new AvgFunction.ReturnTypeResolver( typeConfiguration ) )
20172017
.register();
20182018

20192019
functionRegistry.register(

hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package org.hibernate.dialect.function;
88

99
import org.hibernate.metamodel.mapping.BasicValuedMapping;
10+
import org.hibernate.metamodel.mapping.JdbcMapping;
1011
import org.hibernate.metamodel.mapping.MappingModelExpressible;
1112
import org.hibernate.query.ReturnableType;
1213
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
@@ -17,12 +18,12 @@
1718

1819
import java.math.BigDecimal;
1920
import java.math.BigInteger;
20-
import java.sql.Types;
2121
import java.util.List;
2222
import java.util.function.Supplier;
2323

2424
import static org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers.extractArgumentType;
2525
import static org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers.extractArgumentValuedMapping;
26+
import static org.hibernate.type.SqlTypes.*;
2627

2728
/**
2829
* Resolve according to JPA spec 4.8.5
@@ -84,18 +85,20 @@ public ReturnableType<?> resolveFunctionReturnType(
8485
}
8586
}
8687
switch ( basicType.getJdbcType().getDefaultSqlTypeCode() ) {
87-
case Types.SMALLINT:
88-
case Types.TINYINT:
89-
case Types.INTEGER:
90-
case Types.BIGINT:
88+
case SMALLINT:
89+
case TINYINT:
90+
case INTEGER:
91+
case BIGINT:
9192
return longType;
92-
case Types.FLOAT:
93-
case Types.REAL:
94-
case Types.DOUBLE:
93+
case FLOAT:
94+
case REAL:
95+
case DOUBLE:
9596
return doubleType;
96-
case Types.DECIMAL:
97-
case Types.NUMERIC:
97+
case DECIMAL:
98+
case NUMERIC:
9899
return BigInteger.class.isAssignableFrom( basicType.getJavaType() ) ? bigIntegerType : bigDecimalType;
100+
case VECTOR:
101+
return basicType;
99102
}
100103
return bigDecimalType;
101104
}
@@ -112,22 +115,23 @@ public BasicValuedMapping resolveFunctionReturnType(
112115
}
113116
// Resolve according to JPA spec 4.8.5
114117
final BasicValuedMapping specifiedArgType = extractArgumentValuedMapping( arguments, 1 );
115-
switch ( specifiedArgType.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) {
116-
case Types.SMALLINT:
117-
case Types.TINYINT:
118-
case Types.INTEGER:
119-
case Types.BIGINT:
118+
final JdbcMapping jdbcMapping = specifiedArgType.getJdbcMapping();
119+
switch ( jdbcMapping.getJdbcType().getDefaultSqlTypeCode() ) {
120+
case SMALLINT:
121+
case TINYINT:
122+
case INTEGER:
123+
case BIGINT:
120124
return longType;
121-
case Types.FLOAT:
122-
case Types.REAL:
123-
case Types.DOUBLE:
125+
case FLOAT:
126+
case REAL:
127+
case DOUBLE:
124128
return doubleType;
125-
case Types.DECIMAL:
126-
case Types.NUMERIC:
127-
final Class<?> argTypeClass = specifiedArgType.getJdbcMapping()
128-
.getJavaTypeDescriptor()
129-
.getJavaTypeClass();
130-
return BigInteger.class.isAssignableFrom(argTypeClass) ? bigIntegerType : bigDecimalType;
129+
case DECIMAL:
130+
case NUMERIC:
131+
final Class<?> argTypeClass = jdbcMapping.getJavaTypeDescriptor().getJavaTypeClass();
132+
return BigInteger.class.isAssignableFrom( argTypeClass ) ? bigIntegerType : bigDecimalType;
133+
case VECTOR:
134+
return (BasicValuedMapping) jdbcMapping;
131135
}
132136
return bigDecimalType;
133137
}

hibernate-core/src/main/java/org/hibernate/query/sqm/internal/TypecheckUtil.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.hibernate.query.sqm.tree.domain.SqmPluralValuedSimplePath;
2626
import org.hibernate.query.sqm.tree.expression.SqmExpression;
2727
import org.hibernate.query.sqm.tree.expression.SqmLiteralNull;
28+
import org.hibernate.type.BasicPluralType;
2829
import org.hibernate.type.BasicType;
2930
import org.hibernate.type.descriptor.jdbc.JdbcType;
3031

@@ -448,6 +449,14 @@ else if ( Temporal.class.isAssignableFrom( leftJavaType )
448449
+ " (it is not an instance of 'java.lang.Number')" );
449450
}
450451
}
452+
else if ( isNumberArray( leftNodeType ) ) {
453+
// left operand is a number
454+
if ( !isNumberArray( rightNodeType ) ) {
455+
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
456+
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a numeric array type"
457+
+ " (it is not an instance of 'java.lang.Number[]')" );
458+
}
459+
}
451460
else {
452461
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
453462
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
@@ -456,6 +465,16 @@ else if ( Temporal.class.isAssignableFrom( leftJavaType )
456465
}
457466
}
458467

468+
private static boolean isNumberArray(SqmExpressible<?> expressible) {
469+
final DomainType<?> domainType;
470+
if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
471+
return domainType instanceof BasicPluralType<?, ?> && Number.class.isAssignableFrom(
472+
( (BasicPluralType<?, ?>) domainType ).getElementType().getJavaType()
473+
);
474+
}
475+
return false;
476+
}
477+
459478
public static void assertString(SqmExpression<?> expression) {
460479
final SqmExpressible<?> nodeType = expression.getNodeType();
461480
if ( nodeType != null ) {

0 commit comments

Comments
 (0)