Skip to content

Commit ee9d4fe

Browse files
committed
wip
1 parent cad383f commit ee9d4fe

2 files changed

Lines changed: 40 additions & 47 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -232,25 +232,14 @@ private module Input3 implements InputSig3 {
232232
(exists(resolveTupleFieldExpr(_, _)) implies any())
233233
}
234234

235+
predicate inferTypeForDefaults = M3::inferType/2;
236+
235237
class BoolType extends DataType {
236238
BoolType() { this.getTypeItem() instanceof Builtins::Bool }
237239
}
238240

239241
class AstNode = Rust::AstNode;
240242

241-
// todo: remove
242-
TypeMention getTypeAnnotation(AstNode n) {
243-
exists(Static static |
244-
n = static and
245-
result = static.getTypeRepr()
246-
)
247-
or
248-
exists(Const c |
249-
n = c and
250-
result = c.getTypeRepr()
251-
)
252-
}
253-
254243
class Expr = Rust::Expr;
255244

256245
class Cast extends CastExpr {
@@ -978,6 +967,18 @@ private module Input3 implements InputSig3 {
978967
result instanceof UnitType
979968
or
980969
result = inferClosureArgsType(n, path)
970+
or
971+
exists(TypeMention tm | result = tm.getTypeAt(path) |
972+
exists(Static static |
973+
n = static and
974+
tm = static.getTypeRepr()
975+
)
976+
or
977+
exists(Const c |
978+
n = c and
979+
tm = c.getTypeRepr()
980+
)
981+
)
981982
}
982983

983984
predicate inferStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
@@ -1085,8 +1086,6 @@ private module Input3 implements InputSig3 {
10851086
or
10861087
result = inferUnknownType(n, path)
10871088
}
1088-
1089-
predicate inferTypeForDefaults = M3::inferType/2;
10901089
}
10911090

10921091
private module M3 = Make3<Input3>;

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
* parameters) and constructs the `TypePath` type used to represent paths into
88
* compound types.
99
*
10-
* 2. `Make2`, which takes as input a definition of type mentions (using the `TypePath`
11-
* type) as well as the type hierarchy and type constraints, and constructs the
10+
* 2. `Make2`, which (using the `TypePath` type) takes as input a definition of type
11+
* mentions as well as the type hierarchy and type constraints, and constructs the
1212
* `Matching` and `IsInstantiationOf` modules, which are core building blocks for
1313
* matching type instantiations against type parameters, taking the type hierarchy
1414
* and type constraints into account.
@@ -47,8 +47,8 @@
4747
*
4848
* where the type of `Default::default()` needs to be inferred from the context, we
4949
*
50-
* 1. assign `Default::default()` the special `UnknownType`,
51-
* 2. using the `cond-then` rule we conclude that the conditional has type `i64`, and
50+
* 1. conclude that the conditional has type `i64`, using the `cond-then` rule,
51+
* 2. assign `Default::default()` the special `UnknownType`, and
5252
* 3. since the `else` branch has `UnknownType`, we apply the `cond-else` rule _backwards_
5353
* to infer that `Default::default()` has type `i64`.
5454
*
@@ -1563,9 +1563,11 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
15631563
) {
15641564
constraint = getATypeParameterConstraint(constrainedTp, target) and
15651565
(
1566-
constrainedTp = target.getDeclaredType(_, _)
1567-
or
15681566
constrainedTp = target.getTypeParameter(_)
1567+
or
1568+
// a declaration may reference type parameters that are not declared on it,
1569+
// for type parameters from the enclosing type
1570+
constrainedTp = target.getDeclaredType(_, _)
15691571
)
15701572
}
15711573

@@ -1929,6 +1931,13 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
19291931
*/
19301932
default predicate cacheRevRef() { none() }
19311933

1934+
/**
1935+
* This predicate must be implemented as an alias for the the `inferType` predicate
1936+
* defined in this module, and is needed in order to provide default implementations
1937+
* inside this signature.
1938+
*/
1939+
Type inferTypeForDefaults(AstNode n, TypePath path);
1940+
19321941
/** A boolean type. */
19331942
class BoolType extends Type;
19341943

@@ -1941,10 +1950,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
19411950
Location getLocation();
19421951
}
19431952

1944-
// todo: remove
1945-
/** Gets the type annotation that applies to `n`, if any. */
1946-
TypeMention getTypeAnnotation(AstNode n);
1947-
19481953
/** An expression. */
19491954
class Expr extends AstNode;
19501955

@@ -2028,12 +2033,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
20282033
* respect to type inference, for example a `const` or `static` in Rust.
20292034
*/
20302035
class LocalVariable {
2031-
/**
2032-
* Gets the AST node that defines this variable.
2033-
*
2034-
* If this variable is explicitly typed, then the type annotation must be
2035-
* applied to the defining node in `getTypeAnnotation`.
2036-
*/
2036+
/** Gets the AST node that defines this variable. */
20372037
AstNode getDefiningNode();
20382038

20392039
/** Gets an access to this variable. */
@@ -2381,15 +2381,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
23812381
* nodes where `inferStep` cannot be used, such as leaf nodes in the AST.
23822382
*/
23832383
Type inferTypeSpecific(AstNode n, TypePath path);
2384-
2385-
/**
2386-
* Gets the inferred type of `n` at `path`.
2387-
*
2388-
* This predicate must be implemented as an alias for the the `inferType` predicate
2389-
* defined in this module, and is needed in order to provide default implementations
2390-
* inside this signature.
2391-
*/
2392-
Type inferTypeForDefaults(AstNode n, TypePath path);
23932384
}
23942385

23952386
module Make3<InputSig3 Input3> {
@@ -2398,8 +2389,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
23982389
/** Gets the type of `n`, which has an explicit type annotation. */
23992390
pragma[nomagic]
24002391
private Type inferAnnotatedType(AstNode n, TypePath path) {
2401-
result = getTypeAnnotation(n).getTypeAt(path)
2402-
or
24032392
result = n.(Cast).getType().getTypeAt(path)
24042393
or
24052394
exists(LocalVariableDeclaration decl |
@@ -2961,7 +2950,12 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
29612950
class Declaration extends ParameterizableFinal {
29622951
TypeParameter getTypeParameter(int pos) {
29632952
InvocationTypeQualifierMatching::typeMatch(_, _, this, _, _, result) and
2964-
pos = -getTypeParameterId(result) - 1
2953+
pos = -getTypeParameterId(result) - 2
2954+
or
2955+
// blanket implementations in Rust have a declaring type that is a type parameter;
2956+
// those should be matched against the entire type qualifier
2957+
result = this.getDeclaringType(TypePath::nil()) and
2958+
pos = -1
29652959
or
29662960
result = super.getTypeParameter(pos)
29672961
// none()
@@ -2972,9 +2966,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
29722966
or
29732967
dpos = getReturnPosition() and
29742968
result = getParameterizableType(this, path)
2975-
// or
2976-
// dpos = getDeclaringPosition() and
2977-
// result = this.getDeclaringType(path)
29782969
}
29792970
}
29802971

@@ -2989,9 +2980,12 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
29892980
Type getTypeArgument(int pos, TypePath path) {
29902981
exists(TypeParameter tp |
29912982
InvocationTypeQualifierMatching::typeMatch(this, _, _, path, result, tp) and
2992-
pos = -getTypeParameterId(tp) - 1
2983+
pos = -getTypeParameterId(tp) - 2
29932984
)
29942985
or
2986+
pos = -1 and
2987+
result = this.getTypeQualifier(path)
2988+
or
29952989
result = super.getTypeArgument(pos, path)
29962990
}
29972991

0 commit comments

Comments
 (0)