Skip to content

Commit 26a04fe

Browse files
authored
Merge pull request #69501 from ktoso/wip-dont-crash-missing-conformance-param
[Distributed] Don't crash in thunk generation when missing SR conformance
2 parents 3f4d242 + 0f5e564 commit 26a04fe

9 files changed

+160
-120
lines changed

include/swift/AST/DistributedDecl.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ Type getDistributedActorIDType(NominalTypeDecl *actor);
5050
/// Similar to `getDistributedSerializationRequirementType`, however, from the
5151
/// perspective of a concrete function. This way we're able to get the
5252
/// serialization requirement for specific members, also in protocols.
53-
Type getConcreteReplacementForMemberSerializationRequirement(ValueDecl *member);
53+
Type getSerializationRequirementTypesForMember(
54+
ValueDecl *member, llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements);
5455

5556
/// Get specific 'SerializationRequirement' as defined in 'nominal'
5657
/// type, which must conform to the passed 'protocol' which is expected
@@ -97,7 +98,7 @@ getDistributedSerializationRequirementProtocols(
9798
/// If so, we can emit slightly nicer diagnostics.
9899
bool checkDistributedSerializationRequirementIsExactlyCodable(
99100
ASTContext &C,
100-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements);
101+
Type type);
101102

102103
/// Get the `SerializationRequirement`, explode it into the specific
103104
/// protocol requirements and insert them into `requirements`.
@@ -114,15 +115,6 @@ getDistributedSerializationRequirements(
114115
ProtocolDecl *protocol,
115116
llvm::SmallPtrSetImpl<ProtocolDecl *> &requirementProtos);
116117

117-
/// Given any set of generic requirements, locate those which are about the
118-
/// `SerializationRequirement`. Those need to be applied in the parameter and
119-
/// return type checking of distributed targets.
120-
llvm::SmallPtrSet<ProtocolDecl *, 2>
121-
extractDistributedSerializationRequirements(
122-
ASTContext &C, ArrayRef<Requirement> allRequirements);
123-
124118
}
125119

126-
// ==== ------------------------------------------------------------------------
127-
128120
#endif /* SWIFT_DECL_DISTRIBUTEDDECL_H */

lib/AST/DistributedDecl.cpp

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ Type swift::getConcreteReplacementForProtocolActorSystemType(ValueDecl *member)
9595
llvm_unreachable("Unable to fetch ActorSystem type!");
9696
}
9797

98-
Type swift::getConcreteReplacementForMemberSerializationRequirement(
99-
ValueDecl *member) {
98+
Type swift::getSerializationRequirementTypesForMember(
99+
ValueDecl *member,
100+
llvm::SmallPtrSet<ProtocolDecl *, 2> &serializationRequirements) {
100101
auto &C = member->getASTContext();
101102
auto *DC = member->getDeclContext();
102103
auto DA = C.getDistributedActorDecl();
@@ -106,17 +107,28 @@ Type swift::getConcreteReplacementForMemberSerializationRequirement(
106107
return getDistributedSerializationRequirementType(classDecl, C.getDistributedActorDecl());
107108
}
108109

109-
/// === Maybe the value is declared in a protocol?
110-
if (auto protocol = DC->getSelfProtocolDecl()) {
110+
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
111+
->getDeclaredInterfaceType();
112+
113+
if (DC->getSelfProtocolDecl() || isa<ExtensionDecl>(DC)) {
111114
GenericSignature signature;
112115
if (auto *genericContext = member->getAsGenericContext()) {
113116
signature = genericContext->getGenericSignature();
114117
} else {
115118
signature = DC->getGenericSignatureOfContext();
116119
}
117120

118-
auto SerReqAssocType = DA->getAssociatedType(C.Id_SerializationRequirement)
119-
->getDeclaredInterfaceType();
121+
// Also store all `SerializationRequirement : SomeProtocol` requirements
122+
for (auto requirement: signature.getRequirements()) {
123+
if (requirement.getFirstType()->isEqual(SerReqAssocType) &&
124+
requirement.getKind() == RequirementKind::Conformance) {
125+
if (auto nominal = requirement.getSecondType()->getAnyNominal()) {
126+
if (auto protocol = dyn_cast<ProtocolDecl>(nominal)) {
127+
serializationRequirements.insert(protocol);
128+
}
129+
}
130+
}
131+
}
120132

121133
// Note that this may be null, e.g. if we're a distributed func inside
122134
// a protocol that did not declare a specific actor system requirement.
@@ -355,15 +367,24 @@ swift::getDistributedSerializationRequirements(
355367

356368
bool swift::checkDistributedSerializationRequirementIsExactlyCodable(
357369
ASTContext &C,
358-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &allRequirements) {
370+
Type type) {
371+
if (!type)
372+
return false;
373+
374+
if (type->hasError())
375+
return false;
376+
359377
auto encodable = C.getProtocol(KnownProtocolKind::Encodable);
360378
auto decodable = C.getProtocol(KnownProtocolKind::Decodable);
361379

362-
if (allRequirements.size() != 2)
380+
auto layout = type->getExistentialLayout();
381+
auto protocols = layout.getProtocols();
382+
383+
if (protocols.size() != 2)
363384
return false;
364385

365-
return allRequirements.count(encodable) &&
366-
allRequirements.count(decodable);
386+
return std::count(protocols.begin(), protocols.end(), encodable) == 1 &&
387+
std::count(protocols.begin(), protocols.end(), decodable) == 1;
367388
}
368389

369390
/******************************************************************************/
@@ -1214,34 +1235,6 @@ AbstractFunctionDecl::isDistributedTargetInvocationResultHandlerOnReturn() const
12141235
return true;
12151236
}
12161237

1217-
llvm::SmallPtrSet<ProtocolDecl *, 2>
1218-
swift::extractDistributedSerializationRequirements(
1219-
ASTContext &C, ArrayRef<Requirement> allRequirements) {
1220-
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationReqs;
1221-
auto DA = C.getDistributedActorDecl();
1222-
auto daSerializationReqAssocType =
1223-
DA->getAssociatedType(C.Id_SerializationRequirement);
1224-
1225-
for (auto req : allRequirements) {
1226-
// FIXME: Seems unprincipled
1227-
if (req.getKind() != RequirementKind::SameType &&
1228-
req.getKind() != RequirementKind::Conformance)
1229-
continue;
1230-
1231-
if (auto dependentMemberType =
1232-
req.getFirstType()->getAs<DependentMemberType>()) {
1233-
if (dependentMemberType->getAssocType() == daSerializationReqAssocType) {
1234-
auto layout = req.getSecondType()->getExistentialLayout();
1235-
for (auto p : layout.getProtocols()) {
1236-
serializationReqs.insert(p);
1237-
}
1238-
}
1239-
}
1240-
}
1241-
1242-
return serializationReqs;
1243-
}
1244-
12451238
/******************************************************************************/
12461239
/********************** Distributed Functions *********************************/
12471240
/******************************************************************************/

lib/Sema/CodeSynthesisDistributedActor.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -842,9 +842,15 @@ FuncDecl *GetDistributedThunkRequest::evaluate(Evaluator &evaluator,
842842
if (!distributedTarget->isDistributed())
843843
return nullptr;
844844
}
845-
846845
assert(distributedTarget);
847846

847+
// This evaluation type-check by now was already computed and cached;
848+
// We need to check in order to avoid emitting a THUNK for a distributed func
849+
// which had errors; as the thunk then may also cause un-addressable issues and confusion.
850+
if (swift::checkDistributedFunction(distributedTarget)) {
851+
return nullptr;
852+
}
853+
848854
auto &C = distributedTarget->getASTContext();
849855

850856
if (!getConcreteReplacementForProtocolActorSystemType(distributedTarget)) {

lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2071,7 +2071,7 @@ static bool checkSingleOverride(ValueDecl *override, ValueDecl *base) {
20712071
return (prop &&
20722072
prop->isFinal() &&
20732073
isa<ClassDecl>(prop->getDeclContext()) &&
2074-
cast<ClassDecl>(prop->getDeclContext())->isActor() &&
2074+
cast<ClassDecl>(prop->getDeclContext())->isAnyActor() &&
20752075
!prop->isStatic() &&
20762076
prop->getName() == ctx.Id_unownedExecutor &&
20772077
prop->getInterfaceType()->getAnyNominal() == ctx.getUnownedSerialExecutorDecl());

lib/Sema/TypeCheckDistributed.cpp

Lines changed: 65 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -376,10 +376,18 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
376376

377377
static bool checkDistributedTargetResultType(
378378
ModuleDecl *module, ValueDecl *valueDecl,
379-
const llvm::SmallPtrSetImpl<ProtocolDecl *> &serializationRequirements,
379+
Type serializationRequirement,
380+
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationRequirements,
380381
bool diagnose) {
381382
auto &C = valueDecl->getASTContext();
382383

384+
if (serializationRequirement && serializationRequirement->hasError()) {
385+
return false;
386+
}
387+
if ((!serializationRequirement || serializationRequirement->hasError()) && serializationRequirements.empty()) {
388+
return false; // error of the type would be diagnosed elsewhere
389+
}
390+
383391
Type resultType;
384392
if (auto func = dyn_cast<FuncDecl>(valueDecl)) {
385393
resultType = func->mapTypeIntoContext(func->getResultInterfaceType());
@@ -392,18 +400,27 @@ static bool checkDistributedTargetResultType(
392400
if (resultType->isVoid())
393401
return false;
394402

403+
404+
// Collect extra "SerializationRequirement: SomeProtocol" requirements
405+
if (serializationRequirement && !serializationRequirement->hasError()) {
406+
auto srl = serializationRequirement->getExistentialLayout();
407+
for (auto s: srl.getProtocols()) {
408+
serializationRequirements.insert(s);
409+
}
410+
}
411+
395412
auto isCodableRequirement =
396413
checkDistributedSerializationRequirementIsExactlyCodable(
397-
C, serializationRequirements);
414+
C, serializationRequirement);
398415

399-
for(auto serializationReq : serializationRequirements) {
416+
for (auto serializationReq: serializationRequirements) {
400417
auto conformance =
401418
TypeChecker::conformsToProtocol(resultType, serializationReq, module);
402419
if (conformance.isInvalid()) {
403420
if (diagnose) {
404421
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
405-
"Codable" : // Codable is a typealias, easier to diagnose like that
406-
serializationReq->getNameStr();
422+
"Codable" : // Codable is a typealias, easier to diagnose like that
423+
serializationReq->getNameStr();
407424

408425
auto diag = valueDecl->diagnose(
409426
diag::distributed_actor_target_result_not_codable,
@@ -418,12 +435,12 @@ static bool checkDistributedTargetResultType(
418435
}
419436
}
420437
} // end if: diagnose
421-
438+
422439
return true;
423440
}
424441
}
425442

426-
return false;
443+
return false;
427444
}
428445

429446
bool swift::checkDistributedActorSystem(const NominalTypeDecl *system) {
@@ -487,66 +504,42 @@ bool CheckDistributedFunctionRequest::evaluate(
487504
}
488505

489506
auto &C = func->getASTContext();
490-
auto DC = func->getDeclContext();
491507
auto module = func->getParentModule();
492508

493509
/// If no distributed module is available, then no reason to even try checks.
494510
if (!C.getLoadedModule(C.Id_Distributed))
495511
return true;
496512

497-
// === All parameters and the result type must conform
498-
// SerializationRequirement
499513
llvm::SmallPtrSet<ProtocolDecl *, 2> serializationRequirements;
500-
if (auto extension = dyn_cast<ExtensionDecl>(DC)) {
501-
serializationRequirements = extractDistributedSerializationRequirements(
502-
C, extension->getGenericRequirements());
503-
} else if (auto actor = dyn_cast<ClassDecl>(DC)) {
504-
serializationRequirements = getDistributedSerializationRequirementProtocols(
505-
getDistributedActorSystemType(actor)->getAnyNominal(),
506-
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
507-
} else if (isa<ProtocolDecl>(DC)) {
508-
if (auto seqReqTy =
509-
getConcreteReplacementForMemberSerializationRequirement(func)) {
510-
auto layout = seqReqTy->getExistentialLayout();
511-
for (auto req : layout.getProtocols()) {
512-
serializationRequirements.insert(req);
513-
}
514-
}
515-
516-
// The distributed actor constrained protocol has no serialization requirements
517-
// or actor system defined, so these will only be enforced, by implementations
518-
// of DAs conforming to it, skip checks here.
519-
if (serializationRequirements.empty()) {
520-
return false;
521-
}
522-
} else {
523-
llvm_unreachable("Distributed function detected in type other than extension, "
524-
"distributed actor, or protocol! This should not be possible "
525-
", please file a bug.");
526-
}
527-
528-
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
529-
auto serializationRequirementIsCodable =
530-
checkDistributedSerializationRequirementIsExactlyCodable(
531-
C, serializationRequirements);
532-
533-
for (auto param : *func->getParameters()) {
534-
// --- Check parameters for 'Codable' conformance
535-
auto paramTy = func->mapTypeIntoContext(param->getInterfaceType());
536-
537-
for (auto req : serializationRequirements) {
538-
if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) {
539-
auto diag = func->diagnose(
540-
diag::distributed_actor_func_param_not_codable,
541-
param->getArgumentName().str(), param->getInterfaceType(),
542-
func->getDescriptiveKind(),
543-
serializationRequirementIsCodable ? "Codable"
544-
: req->getNameStr());
545-
546-
if (auto paramNominalTy = paramTy->getAnyNominal()) {
547-
addCodableFixIt(paramNominalTy, diag);
548-
} // else, no nominal type to suggest the fixit for, e.g. a closure
549-
return true;
514+
Type serializationReqType = getSerializationRequirementTypesForMember(func, serializationRequirements);
515+
516+
for (auto param: *func->getParameters()) {
517+
// --- Check the parameter conforming to serialization requirements
518+
if (serializationReqType && !serializationReqType->hasError()) {
519+
// If the requirement is exactly `Codable` we diagnose it ia bit nicer.
520+
auto serializationRequirementIsCodable =
521+
checkDistributedSerializationRequirementIsExactlyCodable(
522+
C, serializationReqType);
523+
524+
// --- Check parameters for 'SerializationRequirement' conformance
525+
auto paramTy = func->mapTypeIntoContext(param->getInterfaceType());
526+
527+
auto srl = serializationReqType->getExistentialLayout();
528+
for (auto req: srl.getProtocols()) {
529+
if (TypeChecker::conformsToProtocol(paramTy, req, module).isInvalid()) {
530+
auto diag = func->diagnose(
531+
diag::distributed_actor_func_param_not_codable,
532+
param->getArgumentName().str(), param->getInterfaceType(),
533+
func->getDescriptiveKind(),
534+
serializationRequirementIsCodable ? "Codable"
535+
: req->getNameStr());
536+
537+
if (auto paramNominalTy = paramTy->getAnyNominal()) {
538+
addCodableFixIt(paramNominalTy, diag);
539+
} // else, no nominal type to suggest the fixit for, e.g. a closure
540+
541+
return true;
542+
}
550543
}
551544
}
552545

@@ -583,9 +576,10 @@ bool CheckDistributedFunctionRequest::evaluate(
583576
}
584577
}
585578

586-
// --- Result type must be either void or a codable type
587-
if (checkDistributedTargetResultType(module, func, serializationRequirements,
588-
/*diagnose=*/true)) {
579+
// --- Result type must be either void or a serialization requirement conforming type
580+
if (checkDistributedTargetResultType(
581+
module, func, serializationReqType, serializationRequirements,
582+
/*diagnose=*/true)) {
589583
return true;
590584
}
591585

@@ -639,8 +633,11 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
639633
systemDecl,
640634
C.getProtocol(KnownProtocolKind::DistributedActorSystem));
641635

636+
auto serializationRequirement =
637+
getSerializationRequirementTypesForMember(systemVar, serializationRequirements);
638+
642639
auto module = var->getModuleContext();
643-
if (checkDistributedTargetResultType(module, var, serializationRequirements, diagnose)) {
640+
if (checkDistributedTargetResultType(module, var, serializationRequirement, serializationRequirements, diagnose)) {
644641
return true;
645642
}
646643

@@ -740,13 +737,14 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal
740737
(void)nominal->getDistributedActorIDProperty();
741738
}
742739

743-
void TypeChecker::checkDistributedFunc(FuncDecl *func) {
740+
bool TypeChecker::checkDistributedFunc(FuncDecl *func) {
744741
if (!func->isDistributed())
745-
return;
742+
return false;
746743

747-
swift::checkDistributedFunction(func);
744+
return swift::checkDistributedFunction(func);
748745
}
749746

747+
// TODO(distributed): Remove this entirely and rely on generic signature and getConcrete to implement checks
750748
llvm::SmallPtrSet<ProtocolDecl *, 2>
751749
swift::getDistributedSerializationRequirementProtocols(
752750
NominalTypeDecl *nominal, ProtocolDecl *protocol) {

lib/Sema/TypeCheckStmt.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2775,6 +2775,14 @@ TypeCheckFunctionBodyRequest::evaluate(Evaluator &eval,
27752775
// So, build out the body now.
27762776
ASTScope::expandFunctionBody(AFD);
27772777

2778+
if (AFD->isDistributedThunk()) {
2779+
if (auto func = dyn_cast<FuncDecl>(AFD)) {
2780+
if (TypeChecker::checkDistributedFunc(func)) {
2781+
return errorBody();
2782+
}
2783+
}
2784+
}
2785+
27782786
// Type check the function body if needed.
27792787
bool hadError = false;
27802788
if (!alreadyTypeChecked) {

0 commit comments

Comments
 (0)