Skip to content

Commit 9e081b3

Browse files
committed
Rust: Generalize certain type inference logic
1 parent dfe4401 commit 9e081b3

File tree

4 files changed

+143
-91
lines changed

4 files changed

+143
-91
lines changed

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 129 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,6 @@ private Type inferAnnotatedType(AstNode n, TypePath path) {
257257

258258
/** Module for inferring certain type information. */
259259
private module CertainTypeInference {
260-
/** Holds if the type mention does not contain any inferred types `_`. */
261-
predicate typeMentionIsComplete(TypeMention tm) {
262-
not exists(InferTypeRepr t | t.getParentNode*() = tm)
263-
}
264-
265260
/**
266261
* Holds if `ce` is a call where we can infer the type with certainty and if
267262
* `f` is the target of the call and `p` the path invoked by the call.
@@ -343,6 +338,8 @@ private module CertainTypeInference {
343338
let.getPat() = n1 and
344339
let.getInitializer() = n2
345340
)
341+
or
342+
n1 = n2.(ParenExpr).getExpr()
346343
)
347344
or
348345
n1 =
@@ -373,13 +370,57 @@ private module CertainTypeInference {
373370
Type inferCertainType(AstNode n, TypePath path) {
374371
exists(TypeMention tm |
375372
tm = getTypeAnnotation(n) and
376-
typeMentionIsComplete(tm) and
377373
result = tm.resolveTypeAt(path)
378374
)
379375
or
380376
result = inferCertainCallExprType(n, path)
381377
or
382378
result = inferCertainTypeEquality(n, path)
379+
or
380+
result = inferLiteralType(n, path, true)
381+
or
382+
infersCertainTypeAt(n, path, result.getATypeParameter())
383+
}
384+
385+
/**
386+
* Holds if `n` has complete and certain type information at the type path
387+
* `prefix.tp`. This entails that the type at `prefix` must be the type
388+
* that declares `tp`.
389+
*/
390+
pragma[nomagic]
391+
private predicate infersCertainTypeAt(AstNode n, TypePath prefix, TypeParameter tp) {
392+
exists(TypePath path |
393+
exists(inferCertainType(n, path)) and
394+
path.isSnoc(prefix, tp)
395+
)
396+
}
397+
398+
/**
399+
* Holds if `n` has complete and certain type information at _some_ type path.
400+
*/
401+
pragma[nomagic]
402+
predicate hasInferredCertainType(AstNode n) { exists(inferCertainType(n, _)) }
403+
404+
/**
405+
* Holds if `n` having type `t` at `path` conflicts with certain type information.
406+
*/
407+
bindingset[n, path, t]
408+
pragma[inline_late]
409+
predicate certainTypeConflict(AstNode n, TypePath path, Type t) {
410+
inferCertainType(n, path) != t
411+
or
412+
// If we infer that `n` has _some_ type at `T1.T2....Tn`, and we also
413+
// know that `n` certainly has type `certainType` at `T1.T2...Ti`, then
414+
// it must be the case that `T(i+1)` is a type parameter of `certainType`,
415+
// otherwise there is a conflict.
416+
//
417+
// Below, `prefix` is `T1.T2...Ti` and `tp` is `T(i+1)`.
418+
exists(TypePath prefix, TypePath suffix, TypeParameter tp, Type certainType |
419+
path = prefix.appendInverse(suffix) and
420+
tp = suffix.getHead() and
421+
inferCertainType(n, prefix) = certainType and
422+
not certainType.getATypeParameter() = tp
423+
)
383424
}
384425
}
385426

@@ -432,8 +473,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
432473
let.getInitializer() = n2
433474
)
434475
or
435-
n1 = n2.(ParenExpr).getExpr()
436-
or
437476
n1 = n2.(IfExpr).getABranch()
438477
or
439478
n1 = n2.(MatchExpr).getAnArm().getExpr()
@@ -531,9 +570,6 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
531570

532571
pragma[nomagic]
533572
private Type inferTypeEquality(AstNode n, TypePath path) {
534-
// Don't propagate type information into a node for which we already have
535-
// certain type information.
536-
not exists(CertainTypeInference::inferCertainType(n, _)) and
537573
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
538574
result = inferType(n2, prefix2.appendInverse(suffix)) and
539575
path = prefix1.append(suffix)
@@ -1308,17 +1344,22 @@ pragma[nomagic]
13081344
private StructType getStrStruct() { result = TStruct(any(Builtins::Str s)) }
13091345

13101346
pragma[nomagic]
1311-
private Type inferLiteralType(LiteralExpr le, TypePath path) {
1347+
private Type inferLiteralType(LiteralExpr le, TypePath path, boolean certain) {
13121348
path.isEmpty() and
13131349
exists(Builtins::BuiltinType t | result = TStruct(t) |
13141350
le instanceof CharLiteralExpr and
1315-
t instanceof Builtins::Char
1351+
t instanceof Builtins::Char and
1352+
certain = true
13161353
or
13171354
le =
13181355
any(NumberLiteralExpr ne |
1319-
t.getName() = ne.getSuffix()
1356+
t.getName() = ne.getSuffix() and
1357+
certain = true
13201358
or
1359+
// When a number literal has no suffix, the type may depend on the context.
1360+
// For simplicity, we assume either `i32` or `f64`.
13211361
not exists(ne.getSuffix()) and
1362+
certain = false and
13221363
(
13231364
ne instanceof IntegerLiteralExpr and
13241365
t instanceof Builtins::I32
@@ -1329,7 +1370,8 @@ private Type inferLiteralType(LiteralExpr le, TypePath path) {
13291370
)
13301371
or
13311372
le instanceof BooleanLiteralExpr and
1332-
t instanceof Builtins::Bool
1373+
t instanceof Builtins::Bool and
1374+
certain = true
13331375
)
13341376
or
13351377
le instanceof StringLiteralExpr and
@@ -1338,7 +1380,8 @@ private Type inferLiteralType(LiteralExpr le, TypePath path) {
13381380
or
13391381
path = TypePath::singleton(TRefTypeParameter()) and
13401382
result = getStrStruct()
1341-
)
1383+
) and
1384+
certain = true
13421385
}
13431386

13441387
pragma[nomagic]
@@ -2282,62 +2325,71 @@ private module Cached {
22822325
Stages::TypeInferenceStage::ref() and
22832326
result = CertainTypeInference::inferCertainType(n, path)
22842327
or
2285-
result = inferAnnotatedType(n, path)
2286-
or
2287-
result = inferLogicalOperationType(n, path)
2288-
or
2289-
result = inferAssignmentOperationType(n, path)
2290-
or
2291-
result = inferTypeEquality(n, path)
2292-
or
2293-
result = inferImplicitSelfType(n, path)
2294-
or
2295-
result = inferStructExprType(n, path)
2296-
or
2297-
result = inferTupleRootType(n) and
2298-
path.isEmpty()
2299-
or
2300-
result = inferPathExprType(n, path)
2301-
or
2302-
result = inferCallExprBaseType(n, path)
2303-
or
2304-
result = inferFieldExprType(n, path)
2305-
or
2306-
result = inferTupleIndexExprType(n, path)
2307-
or
2308-
result = inferTupleContainerExprType(n, path)
2309-
or
2310-
result = inferRefNodeType(n) and
2311-
path.isEmpty()
2312-
or
2313-
result = inferTryExprType(n, path)
2314-
or
2315-
result = inferLiteralType(n, path)
2316-
or
2317-
result = inferAsyncBlockExprRootType(n) and
2318-
path.isEmpty()
2319-
or
2320-
result = inferAwaitExprType(n, path)
2321-
or
2322-
result = inferArrayExprType(n) and
2323-
path.isEmpty()
2324-
or
2325-
result = inferRangeExprType(n) and
2326-
path.isEmpty()
2327-
or
2328-
result = inferIndexExprType(n, path)
2329-
or
2330-
result = inferForLoopExprType(n, path)
2331-
or
2332-
result = inferDynamicCallExprType(n, path)
2333-
or
2334-
result = inferClosureExprType(n, path)
2335-
or
2336-
result = inferCastExprType(n, path)
2337-
or
2338-
result = inferStructPatType(n, path)
2339-
or
2340-
result = inferTupleStructPatType(n, path)
2328+
// Don't propagate type information into a node which conflicts with certain
2329+
// type information.
2330+
(
2331+
if CertainTypeInference::hasInferredCertainType(n)
2332+
then not CertainTypeInference::certainTypeConflict(n, path, result)
2333+
else any()
2334+
) and
2335+
(
2336+
result = inferAnnotatedType(n, path)
2337+
or
2338+
result = inferLogicalOperationType(n, path)
2339+
or
2340+
result = inferAssignmentOperationType(n, path)
2341+
or
2342+
result = inferTypeEquality(n, path)
2343+
or
2344+
result = inferImplicitSelfType(n, path)
2345+
or
2346+
result = inferStructExprType(n, path)
2347+
or
2348+
result = inferTupleRootType(n) and
2349+
path.isEmpty()
2350+
or
2351+
result = inferPathExprType(n, path)
2352+
or
2353+
result = inferCallExprBaseType(n, path)
2354+
or
2355+
result = inferFieldExprType(n, path)
2356+
or
2357+
result = inferTupleIndexExprType(n, path)
2358+
or
2359+
result = inferTupleContainerExprType(n, path)
2360+
or
2361+
result = inferRefNodeType(n) and
2362+
path.isEmpty()
2363+
or
2364+
result = inferTryExprType(n, path)
2365+
or
2366+
result = inferLiteralType(n, path, false)
2367+
or
2368+
result = inferAsyncBlockExprRootType(n) and
2369+
path.isEmpty()
2370+
or
2371+
result = inferAwaitExprType(n, path)
2372+
or
2373+
result = inferArrayExprType(n) and
2374+
path.isEmpty()
2375+
or
2376+
result = inferRangeExprType(n) and
2377+
path.isEmpty()
2378+
or
2379+
result = inferIndexExprType(n, path)
2380+
or
2381+
result = inferForLoopExprType(n, path)
2382+
or
2383+
result = inferDynamicCallExprType(n, path)
2384+
or
2385+
result = inferClosureExprType(n, path)
2386+
or
2387+
result = inferCastExprType(n, path)
2388+
or
2389+
result = inferStructPatType(n, path)
2390+
or
2391+
result = inferTupleStructPatType(n, path)
2392+
)
23412393
}
23422394
}
23432395

@@ -2438,6 +2490,11 @@ private module Debug {
24382490
c = max(countTypePaths(_, _, _))
24392491
}
24402492

2493+
Type debugInferCertainType(AstNode n, TypePath path) {
2494+
n = getRelevantLocatable() and
2495+
result = CertainTypeInference::inferCertainType(n, path)
2496+
}
2497+
24412498
Type debugInferCertainNonUniqueType(AstNode n, TypePath path) {
24422499
n = getRelevantLocatable() and
24432500
Consistency::nonUniqueCertainType(n, path) and

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2074,7 +2074,7 @@ mod indexers {
20742074
// implicit dereference. We cannot currently handle a position that is
20752075
// both implicitly dereferenced and implicitly borrowed, so the extra
20762076
// type sneaks in.
2077-
let x = slice[0].foo(); // $ target=foo type=x:S target=index SPURIOUS: type=slice:[]
2077+
let x = slice[0].foo(); // $ target=foo type=x:S target=index
20782078
}
20792079

20802080
pub fn f() {

0 commit comments

Comments
 (0)