@@ -376,10 +376,18 @@ bool swift::checkDistributedActorSystemAdHocProtocolRequirements(
376
376
377
377
static bool checkDistributedTargetResultType (
378
378
ModuleDecl *module , ValueDecl *valueDecl,
379
- const llvm::SmallPtrSetImpl<ProtocolDecl *> &serializationRequirements,
379
+ Type serializationRequirement,
380
+ llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationRequirements,
380
381
bool diagnose) {
381
382
auto &C = valueDecl->getASTContext ();
382
383
384
+ if (serializationRequirement && serializationRequirement->hasError ()) {
385
+ return false ;
386
+ }
387
+ if ((!serializationRequirement || serializationRequirement->hasError ()) && serializationRequirements.empty ()) {
388
+ return false ; // error of the type would be diagnosed elsewhere
389
+ }
390
+
383
391
Type resultType;
384
392
if (auto func = dyn_cast<FuncDecl>(valueDecl)) {
385
393
resultType = func->mapTypeIntoContext (func->getResultInterfaceType ());
@@ -392,18 +400,27 @@ static bool checkDistributedTargetResultType(
392
400
if (resultType->isVoid ())
393
401
return false ;
394
402
403
+
404
+ // Collect extra "SerializationRequirement: SomeProtocol" requirements
405
+ if (serializationRequirement && !serializationRequirement->hasError ()) {
406
+ auto srl = serializationRequirement->getExistentialLayout ();
407
+ for (auto s: srl.getProtocols ()) {
408
+ serializationRequirements.insert (s);
409
+ }
410
+ }
411
+
395
412
auto isCodableRequirement =
396
413
checkDistributedSerializationRequirementIsExactlyCodable (
397
- C, serializationRequirements );
414
+ C, serializationRequirement );
398
415
399
- for (auto serializationReq : serializationRequirements) {
416
+ for (auto serializationReq: serializationRequirements) {
400
417
auto conformance =
401
418
TypeChecker::conformsToProtocol (resultType, serializationReq, module );
402
419
if (conformance.isInvalid ()) {
403
420
if (diagnose) {
404
421
llvm::StringRef conformanceToSuggest = isCodableRequirement ?
405
- " Codable" : // Codable is a typealias, easier to diagnose like that
406
- serializationReq->getNameStr ();
422
+ " Codable" : // Codable is a typealias, easier to diagnose like that
423
+ serializationReq->getNameStr ();
407
424
408
425
auto diag = valueDecl->diagnose (
409
426
diag::distributed_actor_target_result_not_codable,
@@ -418,12 +435,12 @@ static bool checkDistributedTargetResultType(
418
435
}
419
436
}
420
437
} // end if: diagnose
421
-
438
+
422
439
return true ;
423
440
}
424
441
}
425
442
426
- return false ;
443
+ return false ;
427
444
}
428
445
429
446
bool swift::checkDistributedActorSystem (const NominalTypeDecl *system) {
@@ -487,66 +504,42 @@ bool CheckDistributedFunctionRequest::evaluate(
487
504
}
488
505
489
506
auto &C = func->getASTContext ();
490
- auto DC = func->getDeclContext ();
491
507
auto module = func->getParentModule ();
492
508
493
509
// / If no distributed module is available, then no reason to even try checks.
494
510
if (!C.getLoadedModule (C.Id_Distributed ))
495
511
return true ;
496
512
497
- // === All parameters and the result type must conform
498
- // SerializationRequirement
499
513
llvm::SmallPtrSet<ProtocolDecl *, 2 > serializationRequirements;
500
- if (auto extension = dyn_cast<ExtensionDecl>(DC)) {
501
- serializationRequirements = extractDistributedSerializationRequirements (
502
- C, extension->getGenericRequirements ());
503
- } else if (auto actor = dyn_cast<ClassDecl>(DC)) {
504
- serializationRequirements = getDistributedSerializationRequirementProtocols (
505
- getDistributedActorSystemType (actor)->getAnyNominal (),
506
- C.getProtocol (KnownProtocolKind::DistributedActorSystem));
507
- } else if (isa<ProtocolDecl>(DC)) {
508
- if (auto seqReqTy =
509
- getConcreteReplacementForMemberSerializationRequirement (func)) {
510
- auto layout = seqReqTy->getExistentialLayout ();
511
- for (auto req : layout.getProtocols ()) {
512
- serializationRequirements.insert (req);
513
- }
514
- }
515
-
516
- // The distributed actor constrained protocol has no serialization requirements
517
- // or actor system defined, so these will only be enforced, by implementations
518
- // of DAs conforming to it, skip checks here.
519
- if (serializationRequirements.empty ()) {
520
- return false ;
521
- }
522
- } else {
523
- llvm_unreachable (" Distributed function detected in type other than extension, "
524
- " distributed actor, or protocol! This should not be possible "
525
- " , please file a bug." );
526
- }
527
-
528
- // If the requirement is exactly `Codable` we diagnose it ia bit nicer.
529
- auto serializationRequirementIsCodable =
530
- checkDistributedSerializationRequirementIsExactlyCodable (
531
- C, serializationRequirements);
532
-
533
- for (auto param : *func->getParameters ()) {
534
- // --- Check parameters for 'Codable' conformance
535
- auto paramTy = func->mapTypeIntoContext (param->getInterfaceType ());
536
-
537
- for (auto req : serializationRequirements) {
538
- if (TypeChecker::conformsToProtocol (paramTy, req, module ).isInvalid ()) {
539
- auto diag = func->diagnose (
540
- diag::distributed_actor_func_param_not_codable,
541
- param->getArgumentName ().str (), param->getInterfaceType (),
542
- func->getDescriptiveKind (),
543
- serializationRequirementIsCodable ? " Codable"
544
- : req->getNameStr ());
545
-
546
- if (auto paramNominalTy = paramTy->getAnyNominal ()) {
547
- addCodableFixIt (paramNominalTy, diag);
548
- } // else, no nominal type to suggest the fixit for, e.g. a closure
549
- return true ;
514
+ Type serializationReqType = getSerializationRequirementTypesForMember (func, serializationRequirements);
515
+
516
+ for (auto param: *func->getParameters ()) {
517
+ // --- Check the parameter conforming to serialization requirements
518
+ if (serializationReqType && !serializationReqType->hasError ()) {
519
+ // If the requirement is exactly `Codable` we diagnose it ia bit nicer.
520
+ auto serializationRequirementIsCodable =
521
+ checkDistributedSerializationRequirementIsExactlyCodable (
522
+ C, serializationReqType);
523
+
524
+ // --- Check parameters for 'SerializationRequirement' conformance
525
+ auto paramTy = func->mapTypeIntoContext (param->getInterfaceType ());
526
+
527
+ auto srl = serializationReqType->getExistentialLayout ();
528
+ for (auto req: srl.getProtocols ()) {
529
+ if (TypeChecker::conformsToProtocol (paramTy, req, module ).isInvalid ()) {
530
+ auto diag = func->diagnose (
531
+ diag::distributed_actor_func_param_not_codable,
532
+ param->getArgumentName ().str (), param->getInterfaceType (),
533
+ func->getDescriptiveKind (),
534
+ serializationRequirementIsCodable ? " Codable"
535
+ : req->getNameStr ());
536
+
537
+ if (auto paramNominalTy = paramTy->getAnyNominal ()) {
538
+ addCodableFixIt (paramNominalTy, diag);
539
+ } // else, no nominal type to suggest the fixit for, e.g. a closure
540
+
541
+ return true ;
542
+ }
550
543
}
551
544
}
552
545
@@ -583,9 +576,10 @@ bool CheckDistributedFunctionRequest::evaluate(
583
576
}
584
577
}
585
578
586
- // --- Result type must be either void or a codable type
587
- if (checkDistributedTargetResultType (module , func, serializationRequirements,
588
- /* diagnose=*/ true )) {
579
+ // --- Result type must be either void or a serialization requirement conforming type
580
+ if (checkDistributedTargetResultType (
581
+ module , func, serializationReqType, serializationRequirements,
582
+ /* diagnose=*/ true )) {
589
583
return true ;
590
584
}
591
585
@@ -639,8 +633,11 @@ bool swift::checkDistributedActorProperty(VarDecl *var, bool diagnose) {
639
633
systemDecl,
640
634
C.getProtocol (KnownProtocolKind::DistributedActorSystem));
641
635
636
+ auto serializationRequirement =
637
+ getSerializationRequirementTypesForMember (systemVar, serializationRequirements);
638
+
642
639
auto module = var->getModuleContext ();
643
- if (checkDistributedTargetResultType (module , var, serializationRequirements, diagnose)) {
640
+ if (checkDistributedTargetResultType (module , var, serializationRequirement, serializationRequirements, diagnose)) {
644
641
return true ;
645
642
}
646
643
@@ -740,13 +737,14 @@ void TypeChecker::checkDistributedActor(SourceFile *SF, NominalTypeDecl *nominal
740
737
(void )nominal->getDistributedActorIDProperty ();
741
738
}
742
739
743
- void TypeChecker::checkDistributedFunc (FuncDecl *func) {
740
+ bool TypeChecker::checkDistributedFunc (FuncDecl *func) {
744
741
if (!func->isDistributed ())
745
- return ;
742
+ return false ;
746
743
747
- swift::checkDistributedFunction (func);
744
+ return swift::checkDistributedFunction (func);
748
745
}
749
746
747
+ // TODO(distributed): Remove this entirely and rely on generic signature and getConcrete to implement checks
750
748
llvm::SmallPtrSet<ProtocolDecl *, 2 >
751
749
swift::getDistributedSerializationRequirementProtocols (
752
750
NominalTypeDecl *nominal, ProtocolDecl *protocol) {
0 commit comments