Skip to content

Commit a06a902

Browse files
committed
Rust: Handle more explicit type arguments in type inference
1 parent 2e65561 commit a06a902

File tree

12 files changed

+147
-44
lines changed

12 files changed

+147
-44
lines changed

rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ module Impl {
6767
}
6868
}
6969

70-
/** Holds if the call expression dispatches to a trait method. */
70+
/** Holds if the call expression dispatches to a method. */
7171
private predicate callIsMethodCall(CallExpr call, Path qualifier, string methodName) {
7272
exists(Path path, Function f |
7373
path = call.getFunction().(PathExpr).getPath() and

rust/ql/lib/codeql/rust/internal/PathResolution.qll

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ abstract class ItemNode extends Locatable {
181181
result = this.(TypeParamItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
182182
or
183183
result = this.(ImplTraitTypeReprItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
184+
or
185+
result = this.(TypeAliasItemNode).resolveAlias().getASuccessorRec(name) and
186+
not result instanceof TypeParam
184187
}
185188

186189
/**
@@ -289,6 +292,8 @@ abstract class ItemNode extends Locatable {
289292
Location getLocation() { result = super.getLocation() }
290293
}
291294

295+
abstract class TypeItemNode extends ItemNode { }
296+
292297
/** A module or a source file. */
293298
abstract private class ModuleLikeNode extends ItemNode {
294299
/** Gets an item that may refer directly to items defined in this module. */
@@ -438,7 +443,7 @@ private class ConstItemNode extends AssocItemNode instanceof Const {
438443
override TypeParam getTypeParam(int i) { none() }
439444
}
440445

441-
private class EnumItemNode extends ItemNode instanceof Enum {
446+
private class EnumItemNode extends TypeItemNode instanceof Enum {
442447
override string getName() { result = Enum.super.getName().getText() }
443448

444449
override Namespace getNamespace() { result.isType() }
@@ -739,7 +744,7 @@ private class ModuleItemNode extends ModuleLikeNode instanceof Module {
739744
}
740745
}
741746

742-
private class StructItemNode extends ItemNode instanceof Struct {
747+
private class StructItemNode extends TypeItemNode instanceof Struct {
743748
override string getName() { result = Struct.super.getName().getText() }
744749

745750
override Namespace getNamespace() {
@@ -774,7 +779,7 @@ private class StructItemNode extends ItemNode instanceof Struct {
774779
}
775780
}
776781

777-
class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
782+
class TraitItemNode extends ImplOrTraitItemNode, TypeItemNode instanceof Trait {
778783
pragma[nomagic]
779784
Path getABoundPath() {
780785
result = super.getTypeBoundList().getABound().getTypeRepr().(PathTypeRepr).getPath()
@@ -831,7 +836,10 @@ class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
831836
}
832837
}
833838

834-
class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
839+
class TypeAliasItemNode extends TypeItemNode, AssocItemNode instanceof TypeAlias {
840+
pragma[nomagic]
841+
ItemNode resolveAlias() { result = resolvePathFull(super.getTypeRepr().(PathTypeRepr).getPath()) }
842+
835843
override string getName() { result = TypeAlias.super.getName().getText() }
836844

837845
override predicate hasImplementation() { super.hasTypeRepr() }
@@ -847,7 +855,7 @@ class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
847855
override string getCanonicalPath(Crate c) { none() }
848856
}
849857

850-
private class UnionItemNode extends ItemNode instanceof Union {
858+
private class UnionItemNode extends TypeItemNode instanceof Union {
851859
override string getName() { result = Union.super.getName().getText() }
852860

853861
override Namespace getNamespace() { result.isType() }
@@ -905,7 +913,7 @@ private class BlockExprItemNode extends ItemNode instanceof BlockExpr {
905913
override string getCanonicalPath(Crate c) { none() }
906914
}
907915

908-
class TypeParamItemNode extends ItemNode instanceof TypeParam {
916+
class TypeParamItemNode extends TypeItemNode instanceof TypeParam {
909917
private WherePred getAWherePred() {
910918
exists(ItemNode declaringItem |
911919
this = resolveTypeParamPathTypeRepr(result.getTypeRepr()) and

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ private import codeql.typeinference.internal.TypeInference
1010
private import codeql.rust.frameworks.stdlib.Stdlib
1111
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
1212
private import codeql.rust.elements.Call
13+
private import codeql.rust.elements.internal.CallImpl::Impl as CallImpl
1314

1415
class Type = T::Type;
1516

@@ -587,13 +588,13 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
587588
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
588589
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
589590
or
590-
exists(TraitItemNode trait | this = trait.getAnAssocItem() |
591-
typeParamMatchPosition(trait.getTypeParam(_), result, ppos)
591+
exists(ImplOrTraitItemNode i | this = i.getAnAssocItem() |
592+
typeParamMatchPosition(i.getTypeParam(_), result, ppos)
592593
or
593-
ppos.isImplicit() and result = TSelfTypeParameter(trait)
594+
ppos.isImplicit() and result = TSelfTypeParameter(i)
594595
or
595596
ppos.isImplicit() and
596-
result.(AssociatedTypeTypeParameter).getTrait() = trait
597+
result.(AssociatedTypeTypeParameter).getTrait() = i
597598
)
598599
or
599600
ppos.isImplicit() and
@@ -615,6 +616,33 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
615616
or
616617
result = inferImplicitSelfType(self, path) // `self` parameter without type annotation
617618
)
619+
or
620+
// For associated functions, we may also need to match type arguments against
621+
// the `Self` type. For example, in
622+
//
623+
// ```rust
624+
// struct Foo<T>(T);
625+
//
626+
// impl<T : Default> Foo<T> {
627+
// fn default() -> Self {
628+
// Foo(Default::default())
629+
// }
630+
// }
631+
//
632+
// Foo::<i32>::default();
633+
// ```
634+
//
635+
// we need to match `i32` against the type parameter `T` of the `impl` block.
636+
exists(ImplOrTraitItemNode i |
637+
this = i.getAnAssocItem() and
638+
dpos.isSelf() and
639+
not this.getParamList().hasSelfParam()
640+
|
641+
result = TSelfTypeParameter(i) and
642+
path.isEmpty()
643+
or
644+
result = resolveImplSelfType(i, path)
645+
)
618646
}
619647

620648
private Type resolveRetType(TypePath path) {
@@ -686,6 +714,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
686714

687715
Type getInferredType(AccessPosition apos, TypePath path) {
688716
result = inferType(this.getNodeAt(apos), path)
717+
or
718+
// The `Self` type is supplied explicitly as a type qualifier, e.g. `Foo::<Bar>::baz()`
719+
apos = TArgumentAccessPosition(CallImpl::TSelfArgumentPosition(), false, false) and
720+
exists(PathExpr pe, TypeMention tm |
721+
pe = this.(CallExpr).getFunction() and
722+
tm = pe.getPath().getQualifier() and
723+
result = tm.resolveTypeAt(path)
724+
)
689725
}
690726

691727
Declaration getTarget() {
@@ -1074,12 +1110,7 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) {
10741110
}
10751111

10761112
final class MethodCall extends Call {
1077-
MethodCall() {
1078-
exists(this.getReceiver()) and
1079-
// We want the method calls that don't have a path to a concrete method in
1080-
// an impl block. We need to exclude calls like `MyType::my_method(..)`.
1081-
(this instanceof CallExpr implies exists(this.getTrait()))
1082-
}
1113+
MethodCall() { exists(this.getReceiver()) }
10831114

10841115
/** Gets the type of the receiver of the method call at `path`. */
10851116
Type getTypeAt(TypePath path) {

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ abstract class TypeMention extends AstNode {
1515

1616
/** Gets the sub mention at `path`. */
1717
pragma[nomagic]
18-
TypeMention getMentionAt(TypePath path) {
18+
private TypeMention getMentionAt(TypePath path) {
1919
path.isEmpty() and
2020
result = this
2121
or
@@ -49,31 +49,25 @@ class SliceTypeReprMention extends TypeMention instanceof SliceTypeRepr {
4949
override Type resolveType() { result = TSliceType() }
5050
}
5151

52-
class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
53-
Path path;
54-
ItemNode resolved;
52+
class PathTypeMention extends TypeMention, Path {
53+
TypeItemNode resolved;
5554

56-
PathTypeReprMention() {
57-
path = super.getPath() and
58-
// NOTE: This excludes unresolvable paths which is intentional as these
59-
// don't add value to the type inference anyway.
60-
resolved = resolvePath(path)
61-
}
55+
PathTypeMention() { resolved = resolvePath(this) }
6256

6357
ItemNode getResolved() { result = resolved }
6458

6559
pragma[nomagic]
6660
private TypeAlias getResolvedTraitAlias(string name) {
6761
exists(TraitItemNode trait |
68-
trait = resolvePath(path) and
62+
trait = resolved and
6963
result = trait.getAnAssocItem() and
7064
name = result.getName().getText()
7165
)
7266
}
7367

7468
pragma[nomagic]
7569
private TypeRepr getAssocTypeArg(string name) {
76-
result = path.getSegment().getGenericArgList().getAssocTypeArg(name)
70+
result = this.getSegment().getGenericArgList().getAssocTypeArg(name)
7771
}
7872

7973
/** Gets the type argument for the associated type `alias`, if any. */
@@ -86,11 +80,11 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
8680
}
8781

8882
override TypeMention getTypeArgument(int i) {
89-
result = path.getSegment().getGenericArgList().getTypeArg(i)
83+
result = this.getSegment().getGenericArgList().getTypeArg(i)
9084
or
9185
// If a type argument is not given in the path, then we use the default for
9286
// the type parameter if one exists for the type.
93-
not exists(path.getSegment().getGenericArgList().getTypeArg(i)) and
87+
not exists(this.getSegment().getGenericArgList().getTypeArg(i)) and
9488
result = this.resolveType().getTypeParameterDefault(i)
9589
or
9690
// `Self` paths inside `impl` blocks have implicit type arguments that are
@@ -106,7 +100,7 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
106100
//
107101
// the `Self` return type is shorthand for `Foo<T>`.
108102
exists(ImplItemNode node |
109-
path = node.getASelfPath() and
103+
this = node.getASelfPath() and
110104
result = node.(ImplItemNode).getSelfPath().getSegment().getGenericArgList().getTypeArg(i)
111105
)
112106
or
@@ -124,7 +118,7 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
124118
// ```
125119
// the rhs. of the type alias is a type argument to the trait.
126120
exists(ImplItemNode impl, AssociatedTypeTypeParameter param, TypeAlias alias |
127-
path = impl.getTraitPath() and
121+
this = impl.getTraitPath() and
128122
param.getTrait() = resolved and
129123
alias = impl.getASuccessor(param.getTypeAlias().getName().getText()) and
130124
result = alias.getTypeRepr() and
@@ -142,15 +136,15 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
142136
* resulting type at `typePath`.
143137
*/
144138
pragma[nomagic]
145-
Type aliasResolveTypeAt(TypePath typePath) {
139+
private Type aliasResolveTypeAt(TypePath typePath) {
146140
exists(TypeAlias alias, TypeMention rhs | alias = resolved and rhs = alias.getTypeRepr() |
147141
result = rhs.resolveTypeAt(typePath) and
148142
not result = pathGetTypeParameter(alias, _)
149143
or
150144
exists(TypeParameter tp, TypeMention arg, TypePath prefix, TypePath suffix, int i |
151145
tp = rhs.resolveTypeAt(prefix) and
152146
tp = pathGetTypeParameter(alias, i) and
153-
arg = path.getSegment().getGenericArgList().getTypeArg(i) and
147+
arg = this.getSegment().getGenericArgList().getTypeArg(i) and
154148
result = arg.resolveTypeAt(suffix) and
155149
typePath = prefix.append(suffix)
156150
)
@@ -169,7 +163,7 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
169163
exists(TraitItemNode trait | trait = resolved |
170164
// If this is a `Self` path, then it resolves to the implicit `Self`
171165
// type parameter, otherwise it is a trait bound.
172-
if super.getPath() = trait.getASelfPath()
166+
if this = trait.getASelfPath()
173167
then result = TSelfTypeParameter(trait)
174168
else result = TTrait(trait)
175169
)
@@ -192,6 +186,18 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
192186
}
193187
}
194188

189+
class PathTypeReprMention extends TypeMention, PathTypeRepr {
190+
private PathTypeMention path;
191+
192+
PathTypeReprMention() { path = this.getPath() }
193+
194+
override TypeMention getTypeArgument(int i) { result = path.getTypeArgument(i) }
195+
196+
override Type resolveType() { result = path.resolveType() }
197+
198+
override Type resolveTypeAt(TypePath typePath) { result = path.resolveTypeAt(typePath) }
199+
}
200+
195201
class ImplTraitTypeReprMention extends TypeMention instanceof ImplTraitTypeRepr {
196202
override TypeMention getTypeArgument(int i) { none() }
197203

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
illFormedTypeMention
2+
| gen_path_type_repr.rs:6:14:6:20 | ...::Item |
3+
| gen_path_type_repr.rs:6:14:6:20 | ...::Item |

rust/ql/test/extractor-tests/macro-expansion/CONSISTENCY/PathResolutionConsistency.expected

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
multipleCallTargets
2+
| proc_macro.rs:6:18:6:61 | ...::from(...) |
3+
| proc_macro.rs:7:15:7:58 | ...::from(...) |
4+
| proc_macro.rs:15:5:17:5 | ...::new(...) |
25
| proc_macro.rs:16:12:16:16 | ...::to_tokens(...) |
6+
| proc_macro.rs:22:15:22:58 | ...::from(...) |
7+
| proc_macro.rs:25:5:28:5 | ...::new(...) |
38
| proc_macro.rs:26:10:26:12 | ...::to_tokens(...) |
49
| proc_macro.rs:27:10:27:16 | ...::to_tokens(...) |
10+
| proc_macro.rs:38:15:38:64 | ...::from(...) |
11+
| proc_macro.rs:41:5:49:5 | ...::new(...) |
12+
| proc_macro.rs:41:5:49:5 | ...::new(...) |
13+
| proc_macro.rs:41:5:49:5 | ...::new(...) |
14+
| proc_macro.rs:41:5:49:5 | ...::new(...) |
515
| proc_macro.rs:42:16:42:26 | ...::to_tokens(...) |
616
| proc_macro.rs:44:27:44:30 | ...::to_tokens(...) |
717
| proc_macro.rs:46:18:46:28 | ...::to_tokens(...) |
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
illFormedTypeMention
22
| macro_expansion.rs:99:7:99:19 | MyDeriveUnion |
3+
| macro_expansion.rs:99:7:99:19 | MyDeriveUnion |

rust/ql/test/library-tests/path-resolution/CONSISTENCY/PathResolutionConsistency.expected

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
multipleCallTargets
22
| main.rs:118:9:118:11 | f(...) |
3+
| proc_macro.rs:6:16:6:59 | ...::from(...) |
4+
| proc_macro.rs:7:19:7:62 | ...::from(...) |
5+
| proc_macro.rs:9:5:11:5 | ...::new(...) |
36
| proc_macro.rs:10:10:10:12 | ...::to_tokens(...) |
47
multiplePathResolutions
58
| main.rs:626:3:626:12 | proc_macro |

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ mod method_supertraits {
792792
if 3 > 2 { // $ method=gt
793793
self.m1() // $ method=MyTrait1::m1
794794
} else {
795-
Self::m1(self)
795+
Self::m1(self) // $ method=MyTrait1::m1
796796
}
797797
}
798798
}
@@ -806,7 +806,7 @@ mod method_supertraits {
806806
if 3 > 2 { // $ method=gt
807807
self.m2().a // $ method=m2 $ fieldof=MyThing
808808
} else {
809-
Self::m2(self).a // $ fieldof=MyThing
809+
Self::m2(self).a // $ method=m2 fieldof=MyThing
810810
}
811811
}
812812
}
@@ -1024,7 +1024,7 @@ mod option_methods {
10241024
struct S;
10251025

10261026
pub fn f() {
1027-
let x1 = MyOption::<S>::new(); // $ MISSING: type=x1:T.S
1027+
let x1 = MyOption::<S>::new(); // $ type=x1:T.S
10281028
println!("{:?}", x1);
10291029

10301030
let mut x2 = MyOption::new();
@@ -1043,7 +1043,7 @@ mod option_methods {
10431043
println!("{:?}", x5.flatten()); // $ method=flatten
10441044

10451045
let x6 = MyOption::MySome(MyOption::<S>::MyNone());
1046-
println!("{:?}", MyOption::<MyOption<S>>::flatten(x6));
1046+
println!("{:?}", MyOption::<MyOption<S>>::flatten(x6)); // $ method=flatten
10471047

10481048
#[rustfmt::skip]
10491049
let from_if = if 3 > 2 { // $ method=gt
@@ -1956,10 +1956,10 @@ mod explicit_type_args {
19561956

19571957
pub fn f() {
19581958
let x1 : Option<S1<S2>> = S1::assoc_fun(); // $ type=x1:T.T.S2
1959-
let x2 = S1::<S2>::assoc_fun(); // $ MISSING: type=x2:T.T.S2
1960-
let x3 = S3::assoc_fun(); // $ MISSING: type=x3:T.T.S2
1961-
let x4 = S1::<S2>::method(S1::default()); // $ MISSING: method=method type=x4:T.S2
1962-
let x5 = S3::method(S1::default()); // $ MISSING: method=method type=x5:T.S2
1959+
let x2 = S1::<S2>::assoc_fun(); // $ type=x2:T.T.S2
1960+
let x3 = S3::assoc_fun(); // $ type=x3:T.T.S2
1961+
let x4 = S1::<S2>::method(S1::default()); // $ method=method type=x4:T.S2
1962+
let x5 = S3::method(S1::default()); // $ method=method type=x5:T.S2
19631963
let x6 = S4::<S2>(S2); // $ type=x6:T4.S2
19641964
let x7 = S4(S2); // $ type=x7:T4.S2
19651965
let x8 = S4(0); // $ type=x8:T4.i32

0 commit comments

Comments
 (0)