8
8
9
9
import java .util .Arrays ;
10
10
import java .util .List ;
11
+ import java .util .Locale ;
12
+ import java .util .function .Supplier ;
11
13
12
14
import org .hibernate .dialect .Dialect ;
15
+ import org .hibernate .metamodel .mapping .BasicValuedMapping ;
13
16
import org .hibernate .metamodel .mapping .JdbcMapping ;
17
+ import org .hibernate .metamodel .mapping .MappingModelExpressible ;
18
+ import org .hibernate .metamodel .model .domain .DomainType ;
14
19
import org .hibernate .query .ReturnableType ;
20
+ import org .hibernate .query .sqm .SqmExpressible ;
15
21
import org .hibernate .query .sqm .function .AbstractSqmSelfRenderingFunctionDescriptor ;
16
22
import org .hibernate .query .sqm .function .FunctionKind ;
17
- import org .hibernate .query .sqm .produce .function .ArgumentTypesValidator ;
18
- import org .hibernate .query .sqm .produce .function .StandardArgumentsValidators ;
23
+ import org .hibernate .query .sqm .produce .function .ArgumentsValidator ;
24
+ import org .hibernate .query .sqm .produce .function .FunctionArgumentException ;
25
+ import org .hibernate .query .sqm .produce .function .FunctionReturnTypeResolver ;
19
26
import org .hibernate .query .sqm .produce .function .StandardFunctionArgumentTypeResolvers ;
20
- import org .hibernate .query .sqm .produce . function . StandardFunctionReturnTypeResolvers ;
27
+ import org .hibernate .query .sqm .tree . SqmTypedNode ;
21
28
import org .hibernate .sql .ast .Clause ;
22
29
import org .hibernate .sql .ast .SqlAstNodeRenderingMode ;
23
30
import org .hibernate .sql .ast .SqlAstTranslator ;
27
34
import org .hibernate .sql .ast .tree .expression .Distinct ;
28
35
import org .hibernate .sql .ast .tree .expression .Expression ;
29
36
import org .hibernate .sql .ast .tree .predicate .Predicate ;
37
+ import org .hibernate .type .BasicPluralType ;
30
38
import org .hibernate .type .BasicType ;
39
+ import org .hibernate .type .SqlTypes ;
31
40
import org .hibernate .type .StandardBasicTypes ;
41
+ import org .hibernate .type .descriptor .java .JavaType ;
42
+ import org .hibernate .type .descriptor .jdbc .ArrayJdbcType ;
43
+ import org .hibernate .type .descriptor .jdbc .JdbcType ;
44
+ import org .hibernate .type .descriptor .jdbc .ObjectJdbcType ;
32
45
import org .hibernate .type .spi .TypeConfiguration ;
33
46
34
47
import static org .hibernate .query .sqm .produce .function .FunctionParameterType .NUMERIC ;
@@ -49,10 +62,8 @@ public AvgFunction(
49
62
super (
50
63
"avg" ,
51
64
FunctionKind .AGGREGATE ,
52
- new ArgumentTypesValidator ( StandardArgumentsValidators .exactly ( 1 ), NUMERIC ),
53
- StandardFunctionReturnTypeResolvers .invariant (
54
- typeConfiguration .getBasicTypeRegistry ().resolve ( StandardBasicTypes .DOUBLE )
55
- ),
65
+ new Validator (),
66
+ new ReturnTypeResolver ( typeConfiguration ),
56
67
StandardFunctionArgumentTypeResolvers .invariant ( typeConfiguration , NUMERIC )
57
68
);
58
69
this .defaultArgumentRenderingMode = defaultArgumentRenderingMode ;
@@ -131,4 +142,116 @@ public String getArgumentListSignature() {
131
142
return "(NUMERIC arg)" ;
132
143
}
133
144
145
+ public static class Validator implements ArgumentsValidator {
146
+
147
+ public static final ArgumentsValidator INSTANCE = new Validator ();
148
+
149
+ @ Override
150
+ public void validate (
151
+ List <? extends SqmTypedNode <?>> arguments ,
152
+ String functionName ,
153
+ TypeConfiguration typeConfiguration ) {
154
+ if ( arguments .size () != 1 ) {
155
+ throw new FunctionArgumentException (
156
+ String .format (
157
+ Locale .ROOT ,
158
+ "Function %s() has %d parameters, but %d arguments given" ,
159
+ functionName ,
160
+ 1 ,
161
+ arguments .size ()
162
+ )
163
+ );
164
+ }
165
+ final SqmTypedNode <?> argument = arguments .get ( 0 );
166
+ final SqmExpressible <?> expressible = argument .getExpressible ();
167
+ final DomainType <?> domainType ;
168
+ if ( expressible != null && ( domainType = expressible .getSqmType () ) != null ) {
169
+ final JdbcType jdbcType = getJdbcType ( domainType , typeConfiguration );
170
+ if ( !isNumeric ( jdbcType ) ) {
171
+ throw new FunctionArgumentException (
172
+ String .format (
173
+ "Parameter %d of function '%s()' has type '%s', but argument is of type '%s'" ,
174
+ 1 ,
175
+ functionName ,
176
+ NUMERIC ,
177
+ domainType .getTypeName ()
178
+ )
179
+ );
180
+ }
181
+ }
182
+ }
183
+
184
+ private static boolean isNumeric (JdbcType jdbcType ) {
185
+ final int sqlTypeCode = jdbcType .getDefaultSqlTypeCode ();
186
+ if ( SqlTypes .isNumericType ( sqlTypeCode ) ) {
187
+ return true ;
188
+ }
189
+ if ( jdbcType instanceof ArrayJdbcType ) {
190
+ return isNumeric ( ( (ArrayJdbcType ) jdbcType ).getElementJdbcType () );
191
+ }
192
+ return false ;
193
+ }
194
+
195
+ private static JdbcType getJdbcType (DomainType <?> domainType , TypeConfiguration typeConfiguration ) {
196
+ if ( domainType instanceof JdbcMapping ) {
197
+ return ( (JdbcMapping ) domainType ).getJdbcType ();
198
+ }
199
+ else {
200
+ final JavaType <?> javaType = domainType .getExpressibleJavaType ();
201
+ if ( javaType .getJavaTypeClass ().isEnum () ) {
202
+ // we can't tell if the enum is mapped STRING or ORDINAL
203
+ return ObjectJdbcType .INSTANCE ;
204
+ }
205
+ else {
206
+ return javaType .getRecommendedJdbcType ( typeConfiguration .getCurrentBaseSqlTypeIndicators () );
207
+ }
208
+ }
209
+ }
210
+
211
+ @ Override
212
+ public String getSignature () {
213
+ return "(arg)" ;
214
+ }
215
+ }
216
+
217
+ public static class ReturnTypeResolver implements FunctionReturnTypeResolver {
218
+
219
+ private final BasicType <Double > doubleType ;
220
+
221
+ public ReturnTypeResolver (TypeConfiguration typeConfiguration ) {
222
+ this .doubleType = typeConfiguration .getBasicTypeRegistry ().resolve ( StandardBasicTypes .DOUBLE );
223
+ }
224
+
225
+ @ Override
226
+ public BasicValuedMapping resolveFunctionReturnType (
227
+ Supplier <BasicValuedMapping > impliedTypeAccess ,
228
+ List <? extends SqlAstNode > arguments ) {
229
+ final BasicValuedMapping impliedType = impliedTypeAccess .get ();
230
+ if ( impliedType != null ) {
231
+ return impliedType ;
232
+ }
233
+ final JdbcMapping jdbcMapping = ( (Expression ) arguments .get ( 0 ) ).getExpressionType ().getSingleJdbcMapping ();
234
+ if ( jdbcMapping instanceof BasicPluralType <?, ?> ) {
235
+ return (BasicValuedMapping ) jdbcMapping ;
236
+ }
237
+ return doubleType ;
238
+ }
239
+
240
+ @ Override
241
+ public ReturnableType <?> resolveFunctionReturnType (
242
+ ReturnableType <?> impliedType ,
243
+ Supplier <MappingModelExpressible <?>> inferredTypeSupplier ,
244
+ List <? extends SqmTypedNode <?>> arguments ,
245
+ TypeConfiguration typeConfiguration ) {
246
+ final SqmExpressible <?> expressible = arguments .get ( 0 ).getExpressible ();
247
+ final DomainType <?> domainType ;
248
+ if ( expressible != null && ( domainType = expressible .getSqmType () ) != null ) {
249
+ if ( domainType instanceof BasicPluralType <?, ?> ) {
250
+ return (ReturnableType <?>) domainType ;
251
+ }
252
+ }
253
+ return typeConfiguration .getBasicTypeRegistry ().resolve ( StandardBasicTypes .DOUBLE );
254
+ }
255
+ }
256
+
134
257
}
0 commit comments