Skip to content

Commit 57bd7d0

Browse files
authored
[ty] Simplify KnownClass::check_call() and KnownFunction::check_call() (#18981)
1 parent 3c18d85 commit 57bd7d0

File tree

3 files changed

+64
-76
lines changed

3 files changed

+64
-76
lines changed

crates/ty_python_semantic/src/types/class.rs

Lines changed: 46 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3201,16 +3201,16 @@ impl KnownClass {
32013201
}
32023202
}
32033203

3204-
/// Evaluate a call to this known class, and emit any diagnostics that are necessary
3205-
/// as a result of the call.
3204+
/// Evaluate a call to this known class, emit any diagnostics that are necessary
3205+
/// as a result of the call, and return the type that results from the call.
32063206
pub(super) fn check_call<'db>(
32073207
self,
32083208
context: &InferContext<'db, '_>,
32093209
index: &SemanticIndex<'db>,
3210-
overload_binding: &mut Binding<'db>,
3210+
overload_binding: &Binding<'db>,
32113211
call_argument_types: &CallArgumentTypes<'_, 'db>,
32123212
call_expression: &ast::ExprCall,
3213-
) {
3213+
) -> Option<Type<'db>> {
32143214
let db = context.db();
32153215
let scope = context.scope();
32163216
let module = context.module();
@@ -3226,10 +3226,9 @@ impl KnownClass {
32263226
let Some(enclosing_class) =
32273227
nearest_enclosing_class(db, index, scope, module)
32283228
else {
3229-
overload_binding.set_return_type(Type::unknown());
32303229
BoundSuperError::UnavailableImplicitArguments
32313230
.report_diagnostic(context, call_expression.into());
3232-
return;
3231+
return Some(Type::unknown());
32333232
};
32343233

32353234
// The type of the first parameter if the given scope is function-like (i.e. function or lambda).
@@ -3249,10 +3248,9 @@ impl KnownClass {
32493248
};
32503249

32513250
let Some(first_param) = first_param else {
3252-
overload_binding.set_return_type(Type::unknown());
32533251
BoundSuperError::UnavailableImplicitArguments
32543252
.report_diagnostic(context, call_expression.into());
3255-
return;
3253+
return Some(Type::unknown());
32563254
};
32573255

32583256
let definition = index.expect_single_definition(first_param);
@@ -3269,7 +3267,7 @@ impl KnownClass {
32693267
Type::unknown()
32703268
});
32713269

3272-
overload_binding.set_return_type(bound_super);
3270+
Some(bound_super)
32733271
}
32743272
[Some(pivot_class_type), Some(owner_type)] => {
32753273
let bound_super = BoundSuperType::build(db, *pivot_class_type, *owner_type)
@@ -3278,9 +3276,9 @@ impl KnownClass {
32783276
Type::unknown()
32793277
});
32803278

3281-
overload_binding.set_return_type(bound_super);
3279+
Some(bound_super)
32823280
}
3283-
_ => {}
3281+
_ => None,
32843282
}
32853283
}
32863284

@@ -3295,14 +3293,12 @@ impl KnownClass {
32953293
_ => None,
32963294
}
32973295
}) else {
3298-
if let Some(builder) =
3299-
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)
3300-
{
3301-
builder.into_diagnostic(
3302-
"A legacy `typing.TypeVar` must be immediately assigned to a variable",
3303-
);
3304-
}
3305-
return;
3296+
let builder =
3297+
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
3298+
builder.into_diagnostic(
3299+
"A legacy `typing.TypeVar` must be immediately assigned to a variable",
3300+
);
3301+
return None;
33063302
};
33073303

33083304
let [
@@ -3315,7 +3311,7 @@ impl KnownClass {
33153311
_infer_variance,
33163312
] = overload_binding.parameter_types()
33173313
else {
3318-
return;
3314+
return None;
33193315
};
33203316

33213317
let covariant = covariant
@@ -3328,39 +3324,30 @@ impl KnownClass {
33283324

33293325
let variance = match (contravariant, covariant) {
33303326
(Truthiness::Ambiguous, _) => {
3331-
let Some(builder) =
3332-
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)
3333-
else {
3334-
return;
3335-
};
3327+
let builder =
3328+
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
33363329
builder.into_diagnostic(
33373330
"The `contravariant` parameter of a legacy `typing.TypeVar` \
33383331
cannot have an ambiguous value",
33393332
);
3340-
return;
3333+
return None;
33413334
}
33423335
(_, Truthiness::Ambiguous) => {
3343-
let Some(builder) =
3344-
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)
3345-
else {
3346-
return;
3347-
};
3336+
let builder =
3337+
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
33483338
builder.into_diagnostic(
33493339
"The `covariant` parameter of a legacy `typing.TypeVar` \
33503340
cannot have an ambiguous value",
33513341
);
3352-
return;
3342+
return None;
33533343
}
33543344
(Truthiness::AlwaysTrue, Truthiness::AlwaysTrue) => {
3355-
let Some(builder) =
3356-
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)
3357-
else {
3358-
return;
3359-
};
3345+
let builder =
3346+
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
33603347
builder.into_diagnostic(
33613348
"A legacy `typing.TypeVar` cannot be both covariant and contravariant",
33623349
);
3363-
return;
3350+
return None;
33643351
}
33653352
(Truthiness::AlwaysTrue, Truthiness::AlwaysFalse) => {
33663353
TypeVarVariance::Contravariant
@@ -3374,11 +3361,8 @@ impl KnownClass {
33743361
let name_param = name_param.into_string_literal().map(|name| name.value(db));
33753362

33763363
if name_param.is_none_or(|name_param| name_param != target.id) {
3377-
let Some(builder) =
3378-
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)
3379-
else {
3380-
return;
3381-
};
3364+
let builder =
3365+
context.report_lint(&INVALID_LEGACY_TYPE_VARIABLE, call_expression)?;
33823366
builder.into_diagnostic(format_args!(
33833367
"The name of a legacy `typing.TypeVar`{} must match \
33843368
the name of the variable it is assigned to (`{}`)",
@@ -3389,7 +3373,7 @@ impl KnownClass {
33893373
},
33903374
target.id,
33913375
));
3392-
return;
3376+
return None;
33933377
}
33943378

33953379
let bound_or_constraint = match (bound, constraints) {
@@ -3414,13 +3398,13 @@ impl KnownClass {
34143398

34153399
// TODO: Emit a diagnostic that TypeVar cannot be both bounded and
34163400
// constrained
3417-
(Some(_), Some(_)) => return,
3401+
(Some(_), Some(_)) => return None,
34183402

34193403
(None, None) => None,
34203404
};
34213405

34223406
let containing_assignment = index.expect_single_definition(target);
3423-
overload_binding.set_return_type(Type::KnownInstance(KnownInstanceType::TypeVar(
3407+
Some(Type::KnownInstance(KnownInstanceType::TypeVar(
34243408
TypeVarInstance::new(
34253409
db,
34263410
target.id.clone(),
@@ -3430,7 +3414,7 @@ impl KnownClass {
34303414
*default,
34313415
TypeVarKind::Legacy,
34323416
),
3433-
)));
3417+
)))
34343418
}
34353419

34363420
KnownClass::TypeAliasType => {
@@ -3446,30 +3430,31 @@ impl KnownClass {
34463430
});
34473431

34483432
let [Some(name), Some(value), ..] = overload_binding.parameter_types() else {
3449-
return;
3433+
return None;
34503434
};
34513435

3452-
if let Some(name) = name.into_string_literal() {
3453-
overload_binding.set_return_type(Type::KnownInstance(
3454-
KnownInstanceType::TypeAliasType(TypeAliasType::Bare(
3436+
name.into_string_literal()
3437+
.map(|name| {
3438+
Type::KnownInstance(KnownInstanceType::TypeAliasType(TypeAliasType::Bare(
34553439
BareTypeAliasType::new(
34563440
db,
34573441
ast::name::Name::new(name.value(db)),
34583442
containing_assignment,
34593443
value,
34603444
),
3461-
)),
3462-
));
3463-
} else if let Some(builder) =
3464-
context.report_lint(&INVALID_TYPE_ALIAS_TYPE, call_expression)
3465-
{
3466-
builder.into_diagnostic(
3467-
"The name of a `typing.TypeAlias` must be a string literal",
3468-
);
3469-
}
3445+
)))
3446+
})
3447+
.or_else(|| {
3448+
let builder =
3449+
context.report_lint(&INVALID_TYPE_ALIAS_TYPE, call_expression)?;
3450+
builder.into_diagnostic(
3451+
"The name of a `typing.TypeAlias` must be a string literal",
3452+
);
3453+
None
3454+
})
34703455
}
34713456

3472-
_ => {}
3457+
_ => None,
34733458
}
34743459
}
34753460
}

crates/ty_python_semantic/src/types/function.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ use crate::types::generics::GenericContext;
7474
use crate::types::narrow::ClassInfoConstraintFunction;
7575
use crate::types::signatures::{CallableSignature, Signature};
7676
use crate::types::{
77-
Binding, BoundMethodType, CallableType, DynamicType, Type, TypeMapping, TypeRelation,
78-
TypeVarInstance,
77+
BoundMethodType, CallableType, DynamicType, Type, TypeMapping, TypeRelation, TypeVarInstance,
7978
};
8079
use crate::{Db, FxOrderSet};
8180

@@ -963,14 +962,14 @@ impl KnownFunction {
963962
pub(super) fn check_call(
964963
self,
965964
context: &InferContext,
966-
overload_binding: &mut Binding,
965+
parameter_types: &[Option<Type<'_>>],
967966
call_expression: &ast::ExprCall,
968967
) {
969968
let db = context.db();
970969

971970
match self {
972971
KnownFunction::RevealType => {
973-
let [Some(revealed_type)] = overload_binding.parameter_types() else {
972+
let [Some(revealed_type)] = parameter_types else {
974973
return;
975974
};
976975
let Some(builder) =
@@ -986,8 +985,7 @@ impl KnownFunction {
986985
);
987986
}
988987
KnownFunction::AssertType => {
989-
let [Some(actual_ty), Some(asserted_ty)] = overload_binding.parameter_types()
990-
else {
988+
let [Some(actual_ty), Some(asserted_ty)] = parameter_types else {
991989
return;
992990
};
993991

@@ -1019,7 +1017,7 @@ impl KnownFunction {
10191017
));
10201018
}
10211019
KnownFunction::AssertNever => {
1022-
let [Some(actual_ty)] = overload_binding.parameter_types() else {
1020+
let [Some(actual_ty)] = parameter_types else {
10231021
return;
10241022
};
10251023
if actual_ty.is_equivalent_to(db, Type::Never) {
@@ -1045,7 +1043,7 @@ impl KnownFunction {
10451043
));
10461044
}
10471045
KnownFunction::StaticAssert => {
1048-
let [Some(parameter_ty), message] = overload_binding.parameter_types() else {
1046+
let [Some(parameter_ty), message] = parameter_types else {
10491047
return;
10501048
};
10511049
let truthiness = match parameter_ty.try_bool(db) {
@@ -1100,8 +1098,7 @@ impl KnownFunction {
11001098
}
11011099
}
11021100
KnownFunction::Cast => {
1103-
let [Some(casted_type), Some(source_type)] = overload_binding.parameter_types()
1104-
else {
1101+
let [Some(casted_type), Some(source_type)] = parameter_types else {
11051102
return;
11061103
};
11071104
let contains_unknown_or_todo =
@@ -1121,7 +1118,7 @@ impl KnownFunction {
11211118
}
11221119
}
11231120
KnownFunction::GetProtocolMembers => {
1124-
let [Some(Type::ClassLiteral(class))] = overload_binding.parameter_types() else {
1121+
let [Some(Type::ClassLiteral(class))] = parameter_types else {
11251122
return;
11261123
};
11271124
if class.is_protocol(db) {
@@ -1130,8 +1127,7 @@ impl KnownFunction {
11301127
report_bad_argument_to_get_protocol_members(context, call_expression, *class);
11311128
}
11321129
KnownFunction::IsInstance | KnownFunction::IsSubclass => {
1133-
let [_, Some(Type::ClassLiteral(class))] = overload_binding.parameter_types()
1134-
else {
1130+
let [_, Some(Type::ClassLiteral(class))] = parameter_types else {
11351131
return;
11361132
};
11371133
let Some(protocol_class) = class.into_protocol_class(db) else {

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5397,21 +5397,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
53975397
match binding_type {
53985398
Type::FunctionLiteral(function_literal) => {
53995399
if let Some(known_function) = function_literal.known(self.db()) {
5400-
known_function.check_call(&self.context, overload, call_expression);
5400+
known_function.check_call(
5401+
&self.context,
5402+
overload.parameter_types(),
5403+
call_expression,
5404+
);
54015405
}
54025406
}
54035407

54045408
Type::ClassLiteral(class) => {
54055409
let Some(known_class) = class.known(self.db()) else {
54065410
continue;
54075411
};
5408-
known_class.check_call(
5412+
let overridden_return = known_class.check_call(
54095413
&self.context,
54105414
self.index,
54115415
overload,
54125416
&call_argument_types,
54135417
call_expression,
54145418
);
5419+
if let Some(overridden_return) = overridden_return {
5420+
overload.set_return_type(overridden_return);
5421+
}
54155422
}
54165423
_ => {}
54175424
}

0 commit comments

Comments
 (0)