Skip to content

Rust: Handle more explicit type arguments in type inference #19847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/ql/lib/codeql/rust/elements/internal/CallImpl.qll
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ module Impl {
}
}

/** Holds if the call expression dispatches to a trait method. */
/** Holds if the call expression dispatches to a method. */
private predicate callIsMethodCall(CallExpr call, Path qualifier, string methodName) {
exists(Path path, Function f |
path = call.getFunction().(PathExpr).getPath() and
Expand Down
20 changes: 14 additions & 6 deletions rust/ql/lib/codeql/rust/internal/PathResolution.qll
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ abstract class ItemNode extends Locatable {
result = this.(TypeParamItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
or
result = this.(ImplTraitTypeReprItemNode).resolveABound().getASuccessorRec(name).(AssocItemNode)
or
result = this.(TypeAliasItemNode).resolveAlias().getASuccessorRec(name) and
not result instanceof TypeParam
}

/**
Expand Down Expand Up @@ -289,6 +292,8 @@ abstract class ItemNode extends Locatable {
Location getLocation() { result = super.getLocation() }
}

abstract class TypeItemNode extends ItemNode { }

/** A module or a source file. */
abstract private class ModuleLikeNode extends ItemNode {
/** Gets an item that may refer directly to items defined in this module. */
Expand Down Expand Up @@ -438,7 +443,7 @@ private class ConstItemNode extends AssocItemNode instanceof Const {
override TypeParam getTypeParam(int i) { none() }
}

private class EnumItemNode extends ItemNode instanceof Enum {
private class EnumItemNode extends TypeItemNode instanceof Enum {
override string getName() { result = Enum.super.getName().getText() }

override Namespace getNamespace() { result.isType() }
Expand Down Expand Up @@ -739,7 +744,7 @@ private class ModuleItemNode extends ModuleLikeNode instanceof Module {
}
}

private class StructItemNode extends ItemNode instanceof Struct {
private class StructItemNode extends TypeItemNode instanceof Struct {
override string getName() { result = Struct.super.getName().getText() }

override Namespace getNamespace() {
Expand Down Expand Up @@ -774,7 +779,7 @@ private class StructItemNode extends ItemNode instanceof Struct {
}
}

class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
class TraitItemNode extends ImplOrTraitItemNode, TypeItemNode instanceof Trait {
pragma[nomagic]
Path getABoundPath() {
result = super.getTypeBoundList().getABound().getTypeRepr().(PathTypeRepr).getPath()
Expand Down Expand Up @@ -831,7 +836,10 @@ class TraitItemNode extends ImplOrTraitItemNode instanceof Trait {
}
}

class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
class TypeAliasItemNode extends TypeItemNode, AssocItemNode instanceof TypeAlias {
pragma[nomagic]
ItemNode resolveAlias() { result = resolvePathFull(super.getTypeRepr().(PathTypeRepr).getPath()) }

override string getName() { result = TypeAlias.super.getName().getText() }

override predicate hasImplementation() { super.hasTypeRepr() }
Expand All @@ -847,7 +855,7 @@ class TypeAliasItemNode extends AssocItemNode instanceof TypeAlias {
override string getCanonicalPath(Crate c) { none() }
}

private class UnionItemNode extends ItemNode instanceof Union {
private class UnionItemNode extends TypeItemNode instanceof Union {
override string getName() { result = Union.super.getName().getText() }

override Namespace getNamespace() { result.isType() }
Expand Down Expand Up @@ -905,7 +913,7 @@ private class BlockExprItemNode extends ItemNode instanceof BlockExpr {
override string getCanonicalPath(Crate c) { none() }
}

class TypeParamItemNode extends ItemNode instanceof TypeParam {
class TypeParamItemNode extends TypeItemNode instanceof TypeParam {
private WherePred getAWherePred() {
exists(ItemNode declaringItem |
this = resolveTypeParamPathTypeRepr(result.getTypeRepr()) and
Expand Down
101 changes: 71 additions & 30 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ private import codeql.typeinference.internal.TypeInference
private import codeql.rust.frameworks.stdlib.Stdlib
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
private import codeql.rust.elements.Call
private import codeql.rust.elements.internal.CallImpl::Impl as CallImpl

class Type = T::Type;

Expand Down Expand Up @@ -353,19 +354,6 @@ private Type inferImplicitSelfType(SelfParam self, TypePath path) {
)
}

/**
* Gets any of the types mentioned in `path` that corresponds to the type
* parameter `tp`.
*/
private TypeMention getExplicitTypeArgMention(Path path, TypeParam tp) {
exists(int i |
result = path.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i)) and
tp = resolvePath(path).getTypeParam(pragma[only_bind_into](i))
)
or
result = getExplicitTypeArgMention(path.getQualifier(), tp)
}

/**
* A matching configuration for resolving types of struct expressions
* like `Foo { bar = baz }`.
Expand Down Expand Up @@ -452,9 +440,7 @@ private module StructExprMatchingInput implements MatchingInputSig {
class AccessPosition = DeclarationPosition;

class Access extends StructExpr {
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
result = getExplicitTypeArgMention(this.getPath(), apos.asTypeParam()).resolveTypeAt(path)
}
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) { none() }

AstNode getNodeAt(AccessPosition apos) {
result = this.getFieldExpr(apos.asFieldPos()).getExpr()
Expand All @@ -465,6 +451,17 @@ private module StructExprMatchingInput implements MatchingInputSig {

Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
or
// The struct type is supplied explicitly as a type qualifier, e.g. TODO
apos.isStructPos() and
exists(TypeMention tm |
// variant
tm = this.getPath().getQualifier()
or
tm = this.getPath()
|
result = tm.resolveTypeAt(path)
)
}

Declaration getTarget() { result = resolvePath(this.getPath()) }
Expand Down Expand Up @@ -537,15 +534,24 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {

abstract Type getReturnType(TypePath path);

final Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = this.getParameterType(dpos, path)
or
dpos.isReturn() and
result = this.getReturnType(path)
}
}

private class TupleStructDecl extends Declaration, Struct {
abstract private class TupleDeclaration extends Declaration {
override Type getDeclaredType(DeclarationPosition dpos, TypePath path) {
result = super.getDeclaredType(dpos, path)
or
dpos.isSelf() and
result = this.getReturnType(path)
}
}

private class TupleStructDecl extends TupleDeclaration, Struct {
TupleStructDecl() { this.isTuple() }

override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
Expand All @@ -568,7 +574,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
}
}

private class TupleVariantDecl extends Declaration, Variant {
private class TupleVariantDecl extends TupleDeclaration, Variant {
TupleVariantDecl() { this.isTuple() }

override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
Expand Down Expand Up @@ -597,13 +603,13 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
override TypeParameter getTypeParameter(TypeParameterPosition ppos) {
typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos)
or
exists(TraitItemNode trait | this = trait.getAnAssocItem() |
typeParamMatchPosition(trait.getTypeParam(_), result, ppos)
exists(ImplOrTraitItemNode i | this = i.getAnAssocItem() |
typeParamMatchPosition(i.getTypeParam(_), result, ppos)
or
ppos.isImplicit() and result = TSelfTypeParameter(trait)
ppos.isImplicit() and result = TSelfTypeParameter(i)
or
ppos.isImplicit() and
result.(AssociatedTypeTypeParameter).getTrait() = trait
result.(AssociatedTypeTypeParameter).getTrait() = i
)
or
ppos.isImplicit() and
Expand All @@ -625,6 +631,33 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
or
result = inferImplicitSelfType(self, path) // `self` parameter without type annotation
)
or
// For associated functions, we may also need to match type arguments against
// the `Self` type. For example, in
//
// ```rust
// struct Foo<T>(T);
//
// impl<T : Default> Foo<T> {
// fn default() -> Self {
// Foo(Default::default())
// }
// }
//
// Foo::<i32>::default();
// ```
//
// we need to match `i32` against the type parameter `T` of the `impl` block.
exists(ImplOrTraitItemNode i |
this = i.getAnAssocItem() and
dpos.isSelf() and
not this.getParamList().hasSelfParam()
|
result = TSelfTypeParameter(i) and
path.isEmpty()
or
result = resolveImplSelfType(i, path)
)
}

private Type resolveRetType(TypePath path) {
Expand Down Expand Up @@ -670,9 +703,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
private import codeql.rust.elements.internal.CallExprImpl::Impl as CallExprImpl

final class Access extends Call {
pragma[nomagic]
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
exists(TypeMention arg | result = arg.resolveTypeAt(path) |
arg = getExplicitTypeArgMention(CallExprImpl::getFunctionPath(this), apos.asTypeParam())
exists(Path p, int i |
p = CallExprImpl::getFunctionPath(this) and
arg = p.getSegment().getGenericArgList().getTypeArg(pragma[only_bind_into](i)) and
apos.asTypeParam() = resolvePath(p).getTypeParam(pragma[only_bind_into](i))
)
or
arg =
this.(MethodCallExpr).getGenericArgList().getTypeArg(apos.asMethodTypeArgumentPosition())
Expand All @@ -696,6 +734,14 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {

Type getInferredType(AccessPosition apos, TypePath path) {
result = inferType(this.getNodeAt(apos), path)
or
// The `Self` type is supplied explicitly as a type qualifier, e.g. `Foo::<Bar>::baz()`
apos = TArgumentAccessPosition(CallImpl::TSelfArgumentPosition(), false, false) and
exists(PathExpr pe, TypeMention tm |
pe = this.(CallExpr).getFunction() and
tm = pe.getPath().getQualifier() and
result = tm.resolveTypeAt(path)
)
}

Declaration getTarget() {
Expand Down Expand Up @@ -1110,12 +1156,7 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
}

final class MethodCall extends Call {
MethodCall() {
exists(this.getReceiver()) and
// We want the method calls that don't have a path to a concrete method in
// an impl block. We need to exclude calls like `MyType::my_method(..)`.
(this instanceof CallExpr implies exists(this.getTrait()))
}
MethodCall() { exists(this.getReceiver()) }

/** Gets the type of the receiver of the method call at `path`. */
Type getTypeAt(TypePath path) {
Expand Down
Loading