Skip to content

Commit cfe2b3c

Browse files
committed
[Typed throws] Implement support for do throws(...) syntax
During the review of SE-0413, typed throws, the notion of a `do throws` syntax for `do..catch` blocks came up. Implement that syntax and semantics, as a way to explicitly specify the type of error that is thrown from the `do` body in `do..catch` statement.
1 parent c7c2f05 commit cfe2b3c

16 files changed

+206
-26
lines changed

include/swift/AST/CatchNode.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class CatchNode: public llvm::PointerUnion<
3535
///
3636
/// Returns the thrown error type for a throwing context, or \c llvm::None
3737
/// if this is a non-throwing context.
38-
llvm::Optional<Type> getThrownErrorTypeInContext(ASTContext &ctx) const;
38+
llvm::Optional<Type> getThrownErrorTypeInContext(DeclContext *dc) const;
3939
};
4040

4141
} // end namespace swift

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,9 @@ ERROR(expected_catch_where_expr,PointsToFirstBadToken,
11961196
ERROR(docatch_not_trycatch,PointsToFirstBadToken,
11971197
"the 'do' keyword is used to specify a 'catch' region",
11981198
())
1199+
ERROR(do_throws_without_catch,none,
1200+
"a 'do' statement with a 'throws' clause must have at least one 'catch'",
1201+
())
11991202

12001203
// C-Style For Stmt
12011204
ERROR(c_style_for_stmt_removed,none,

include/swift/AST/Stmt.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "swift/AST/ConcreteDeclRef.h"
2525
#include "swift/AST/IfConfigClause.h"
2626
#include "swift/AST/TypeAlignments.h"
27+
#include "swift/AST/TypeLoc.h"
2728
#include "swift/AST/ThrownErrorDestination.h"
2829
#include "swift/Basic/Debug.h"
2930
#include "swift/Basic/NullablePtr.h"
@@ -1381,16 +1382,25 @@ class DoCatchStmt final
13811382
: public LabeledStmt,
13821383
private llvm::TrailingObjects<DoCatchStmt, CaseStmt *> {
13831384
friend TrailingObjects;
1385+
friend class DoCatchExplicitThrownTypeRequest;
13841386

13851387
SourceLoc DoLoc;
1388+
1389+
/// Location of the 'throws' token.
1390+
SourceLoc ThrowsLoc;
1391+
1392+
/// The error type that is being thrown.
1393+
TypeLoc ThrownType;
1394+
13861395
Stmt *Body;
13871396
ThrownErrorDestination RethrowDest;
13881397

1389-
DoCatchStmt(LabeledStmtInfo labelInfo, SourceLoc doLoc, Stmt *body,
1398+
DoCatchStmt(LabeledStmtInfo labelInfo, SourceLoc doLoc,
1399+
SourceLoc throwsLoc, TypeLoc thrownType, Stmt *body,
13901400
ArrayRef<CaseStmt *> catches, llvm::Optional<bool> implicit)
13911401
: LabeledStmt(StmtKind::DoCatch, getDefaultImplicitFlag(implicit, doLoc),
13921402
labelInfo),
1393-
DoLoc(doLoc), Body(body) {
1403+
DoLoc(doLoc), ThrowsLoc(throwsLoc), ThrownType(thrownType), Body(body) {
13941404
Bits.DoCatchStmt.NumCatches = catches.size();
13951405
std::uninitialized_copy(catches.begin(), catches.end(),
13961406
getTrailingObjects<CaseStmt *>());
@@ -1400,15 +1410,28 @@ class DoCatchStmt final
14001410

14011411
public:
14021412
static DoCatchStmt *create(ASTContext &ctx, LabeledStmtInfo labelInfo,
1403-
SourceLoc doLoc, Stmt *body,
1413+
SourceLoc doLoc,
1414+
SourceLoc throwsLoc, TypeLoc thrownType,
1415+
Stmt *body,
14041416
ArrayRef<CaseStmt *> catches,
14051417
llvm::Optional<bool> implicit = llvm::None);
14061418

14071419
SourceLoc getDoLoc() const { return DoLoc; }
14081420

1421+
/// Retrieve the location of the 'throws' keyword, if present.
1422+
SourceLoc getThrowsLoc() const { return ThrowsLoc; }
1423+
14091424
SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(DoLoc); }
14101425
SourceLoc getEndLoc() const { return getCatches().back()->getEndLoc(); }
14111426

1427+
/// Retrieves the type representation for the thrown type.
1428+
TypeRepr *getThrownTypeRepr() const {
1429+
return ThrownType.getTypeRepr();
1430+
}
1431+
1432+
// Get the explicitly-specified thrown error type.
1433+
Type getExplicitlyThrownType(DeclContext *dc) const;
1434+
14121435
Stmt *getBody() const { return Body; }
14131436
void setBody(Stmt *s) { Body = s; }
14141437

@@ -1433,7 +1456,7 @@ class DoCatchStmt final
14331456
// and caught by the various 'catch' clauses. If this the catch clauses
14341457
// aren't exhausive, this is also the type of the error that is implicitly
14351458
// rethrown.
1436-
Type getCaughtErrorType() const;
1459+
Type getCaughtErrorType(DeclContext *dc) const;
14371460

14381461
/// Retrieves the rethrown error and its conversion to the error type
14391462
/// expected by the enclosing context.

include/swift/AST/TypeCheckRequests.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class ContextualPattern;
4646
class ContinueStmt;
4747
class DefaultArgumentExpr;
4848
class DefaultArgumentType;
49+
class DoCatchStmt;
4950
struct ExternalMacroDefinition;
5051
class ClosureExpr;
5152
class GenericParamList;
@@ -2303,6 +2304,27 @@ class ThrownTypeRequest
23032304
void cacheResult(Type value) const;
23042305
};
23052306

2307+
/// Determines the explicitly-written thrown error type in a do..catch block.
2308+
class DoCatchExplicitThrownTypeRequest
2309+
: public SimpleRequest<DoCatchExplicitThrownTypeRequest,
2310+
Type(DeclContext *, DoCatchStmt *),
2311+
RequestFlags::SeparatelyCached> {
2312+
public:
2313+
using SimpleRequest::SimpleRequest;
2314+
2315+
private:
2316+
friend SimpleRequest;
2317+
2318+
// Evaluation.
2319+
Type evaluate(Evaluator &evaluator, DeclContext *dc, DoCatchStmt *stmt) const;
2320+
2321+
public:
2322+
// Separate caching.
2323+
bool isCached() const;
2324+
llvm::Optional<Type> getCachedResult() const;
2325+
void cacheResult(Type value) const;
2326+
};
2327+
23062328
/// Determines the result type of a function or element type of a subscript.
23072329
class ResultTypeRequest
23082330
: public SimpleRequest<ResultTypeRequest,

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,9 @@ SWIFT_REQUEST(TypeChecker, ParamSpecifierRequest,
360360
ParamDecl::Specifier(ParamDecl *), SeparatelyCached, NoLocationInfo)
361361
SWIFT_REQUEST(TypeChecker, ThrownTypeRequest,
362362
Type(AbstractFunctionDecl *), SeparatelyCached, NoLocationInfo)
363+
SWIFT_REQUEST(TypeChecker, DoCatchExplicitThrownTypeRequest,
364+
Type(DeclContext *, DoCatchStmt *), SeparatelyCached,
365+
NoLocationInfo)
363366
SWIFT_REQUEST(TypeChecker, ResultTypeRequest,
364367
Type(ValueDecl *), SeparatelyCached, NoLocationInfo)
365368
SWIFT_REQUEST(TypeChecker, AreAllStoredPropertiesDefaultInitableRequest,

lib/AST/ASTVerifier.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,13 +1020,23 @@ class Verifier : public ASTWalker {
10201020
return shouldVerifyChecked(S->getSubExpr());
10211021
}
10221022

1023+
DeclContext *getInnermostDC() const {
1024+
for (auto scope : llvm::reverse(Scopes)) {
1025+
if (auto dc = scope.dyn_cast<DeclContext *>())
1026+
return dc;
1027+
}
1028+
1029+
return nullptr;
1030+
}
1031+
10231032
void verifyChecked(ThrowStmt *S) {
10241033
Type thrownError;
10251034
SourceLoc loc = S->getThrowLoc();
10261035
if (loc.isValid()) {
10271036
auto catchNode = ASTScope::lookupCatchNode(getModuleContext(), loc);
1028-
if (catchNode) {
1029-
if (auto thrown = catchNode.getThrownErrorTypeInContext(Ctx)) {
1037+
DeclContext *dc = getInnermostDC();
1038+
if (catchNode && dc) {
1039+
if (auto thrown = catchNode.getThrownErrorTypeInContext(dc)) {
10301040
thrownError = *thrown;
10311041
} else {
10321042
thrownError = Ctx.getNeverType();

lib/AST/Decl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11691,7 +11691,7 @@ MacroDiscriminatorContext::getParentOf(FreestandingMacroExpansion *expansion) {
1169111691
}
1169211692

1169311693
llvm::Optional<Type>
11694-
CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
11694+
CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
1169511695
if (auto func = dyn_cast<AbstractFunctionDecl *>()) {
1169611696
if (auto thrownError = func->getEffectiveThrownErrorType())
1169711697
return func->mapTypeIntoContext(*thrownError);
@@ -11708,13 +11708,13 @@ CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
1170811708
}
1170911709

1171011710
auto doCatch = get<DoCatchStmt *>();
11711-
if (auto thrownError = doCatch->getCaughtErrorType()) {
11711+
if (auto thrownError = doCatch->getCaughtErrorType(dc)) {
1171211712
if (thrownError->isNever())
1171311713
return llvm::None;
1171411714

1171511715
return thrownError;
1171611716
}
1171711717

1171811718
// If we haven't computed the error type yet, do so now.
11719-
return ctx.getErrorExistentialType();
11719+
return dc->getASTContext().getErrorExistentialType();
1172011720
}

lib/AST/Stmt.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,12 +450,15 @@ Expr *ForEachStmt::getTypeCheckedSequence() const {
450450
}
451451

452452
DoCatchStmt *DoCatchStmt::create(ASTContext &ctx, LabeledStmtInfo labelInfo,
453-
SourceLoc doLoc, Stmt *body,
453+
SourceLoc doLoc,
454+
SourceLoc throwsLoc, TypeLoc thrownType,
455+
Stmt *body,
454456
ArrayRef<CaseStmt *> catches,
455457
llvm::Optional<bool> implicit) {
456458
void *mem = ctx.Allocate(totalSizeToAlloc<CaseStmt *>(catches.size()),
457459
alignof(DoCatchStmt));
458-
return ::new (mem) DoCatchStmt(labelInfo, doLoc, body, catches, implicit);
460+
return ::new (mem) DoCatchStmt(labelInfo, doLoc, throwsLoc, thrownType, body,
461+
catches, implicit);
459462
}
460463

461464
bool CaseLabelItem::isSyntacticallyExhaustive() const {
@@ -472,7 +475,17 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const {
472475
return false;
473476
}
474477

475-
Type DoCatchStmt::getCaughtErrorType() const {
478+
Type DoCatchStmt::getExplicitlyThrownType(DeclContext *dc) const {
479+
ASTContext &ctx = dc->getASTContext();
480+
DoCatchExplicitThrownTypeRequest request{dc, const_cast<DoCatchStmt *>(this)};
481+
return evaluateOrDefault(ctx.evaluator, request, Type());
482+
}
483+
484+
Type DoCatchStmt::getCaughtErrorType(DeclContext *dc) const {
485+
// Check for an explicitly-specified error type.
486+
if (Type explicitError = getExplicitlyThrownType(dc))
487+
return explicitError;
488+
476489
auto firstPattern = getCatches()
477490
.front()
478491
->getCaseLabelItems()

lib/AST/TypeCheckRequests.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,29 @@ void ThrownTypeRequest::cacheResult(Type type) const {
969969
func->ThrownType.setType(type);
970970
}
971971

972+
//----------------------------------------------------------------------------//
973+
// DoCatchExplicitThrownTypeRequest computation.
974+
//----------------------------------------------------------------------------//
975+
976+
bool DoCatchExplicitThrownTypeRequest::isCached() const {
977+
auto *const stmt = std::get<1>(getStorage());
978+
return stmt->getThrowsLoc().isValid();
979+
}
980+
981+
llvm::Optional<Type> DoCatchExplicitThrownTypeRequest::getCachedResult() const {
982+
auto *const stmt = std::get<1>(getStorage());
983+
Type thrownType = stmt->ThrownType.getType();
984+
if (thrownType.isNull())
985+
return llvm::None;
986+
987+
return thrownType;
988+
}
989+
990+
void DoCatchExplicitThrownTypeRequest::cacheResult(Type type) const {
991+
auto *const stmt = std::get<1>(getStorage());
992+
stmt->ThrownType.setType(type);
993+
}
994+
972995
//----------------------------------------------------------------------------//
973996
// ResultTypeRequest computation.
974997
//----------------------------------------------------------------------------//

lib/Parse/ParseStmt.cpp

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2191,8 +2191,8 @@ ParserResult<Stmt> Parser::parseStmtRepeat(LabeledStmtInfo labelInfo) {
21912191

21922192
///
21932193
/// stmt-do:
2194-
/// (identifier ':')? 'do' stmt-brace
2195-
/// (identifier ':')? 'do' stmt-brace stmt-catch+
2194+
/// (identifier ':')? 'do' throws-clause? stmt-brace
2195+
/// (identifier ':')? 'do' throws-clause? stmt-brace stmt-catch+
21962196
ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
21972197
bool shouldSkipDoTokenConsume) {
21982198
SourceLoc doLoc;
@@ -2205,6 +2205,25 @@ ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
22052205

22062206
ParserStatus status;
22072207

2208+
// Parse the optional 'throws' clause.
2209+
SourceLoc throwsLoc;
2210+
TypeRepr *thrownType = nullptr;
2211+
if (consumeIf(tok::kw_throws, throwsLoc)) {
2212+
// Parse the thrown error type.
2213+
SourceLoc lParenLoc;
2214+
if (consumeIf(tok::l_paren, lParenLoc)) {
2215+
ParserResult<TypeRepr> parsedThrownTy =
2216+
parseType(diag::expected_thrown_error_type);
2217+
thrownType = parsedThrownTy.getPtrOrNull();
2218+
status |= parsedThrownTy;
2219+
2220+
SourceLoc rParenLoc;
2221+
parseMatchingToken(
2222+
tok::r_paren, rParenLoc,
2223+
diag::expected_rparen_after_thrown_error_type, lParenLoc);
2224+
}
2225+
}
2226+
22082227
ParserResult<BraceStmt> body =
22092228
parseBraceItemList(diag::expected_lbrace_after_do);
22102229
status |= body;
@@ -2236,7 +2255,12 @@ ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
22362255
}
22372256

22382257
return makeParserResult(status,
2239-
DoCatchStmt::create(Context, labelInfo, doLoc, body.get(), allClauses));
2258+
DoCatchStmt::create(Context, labelInfo, doLoc, throwsLoc, thrownType,
2259+
body.get(), allClauses));
2260+
}
2261+
2262+
if (throwsLoc.isValid()) {
2263+
diagnose(throwsLoc, diag::do_throws_without_catch);
22402264
}
22412265

22422266
// If we dont see a 'while' or see a 'while' that starts

lib/SILGen/SILGenStmt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,7 @@ void StmtEmitter::visitDoStmt(DoStmt *S) {
11171117
}
11181118

11191119
void StmtEmitter::visitDoCatchStmt(DoCatchStmt *S) {
1120-
Type formalExnType = S->getCaughtErrorType();
1120+
Type formalExnType = S->getCaughtErrorType(SGF.FunctionDC);
11211121
auto &exnTL = SGF.getTypeLowering(formalExnType);
11221122

11231123
SILValue exnArg;

lib/Sema/TypeCheckEffects.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2788,9 +2788,10 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
27882788
/// Retrieve the type of the error that can be caught when an error is
27892789
/// thrown from the given location.
27902790
Type getCaughtErrorTypeAt(SourceLoc loc) {
2791-
auto module = CurContext.getDeclContext()->getParentModule();
2791+
auto dc = CurContext.getDeclContext();
2792+
auto module = dc->getParentModule();
27922793
if (CatchNode catchNode = ASTScope::lookupCatchNode(module, loc)) {
2793-
if (auto caughtType = catchNode.getThrownErrorTypeInContext(Ctx))
2794+
if (auto caughtType = catchNode.getThrownErrorTypeInContext(dc))
27942795
return *caughtType;
27952796
}
27962797

@@ -2917,7 +2918,8 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
29172918
// specialized diagnostic about non-exhaustive catches.
29182919
if (!CurContext.handlesThrows(ConditionalEffectKind::Conditional)) {
29192920
CurContext.setNonExhaustiveCatch(true);
2920-
} else if (Type rethrownErrorType = S->getCaughtErrorType()) {
2921+
} else if (Type rethrownErrorType =
2922+
S->getCaughtErrorType(CurContext.getDeclContext())) {
29212923
// We're implicitly rethrowing the error out of this do..catch, so make
29222924
// sure that we can throw an error of this type out of this context.
29232925
auto catches = S->getCatches();
@@ -3554,15 +3556,23 @@ llvm::Optional<Type> TypeChecker::canThrow(ASTContext &ctx, Expr *expr) {
35543556
return classification.getThrownError();
35553557
}
35563558

3557-
Type TypeChecker::catchErrorType(ASTContext &ctx, DoCatchStmt *stmt) {
3559+
Type TypeChecker::catchErrorType(DeclContext *dc, DoCatchStmt *stmt) {
3560+
ASTContext &ctx = dc->getASTContext();
3561+
35583562
// When typed throws is disabled, this is always "any Error".
35593563
// FIXME: When we distinguish "precise" typed throws from normal typed
35603564
// throws, we'll be able to compute a more narrow catch error type in some
35613565
// case, e.g., from a `try` but not a `throws`.
35623566
if (!ctx.LangOpts.hasFeature(Feature::TypedThrows))
35633567
return ctx.getErrorExistentialType();
35643568

3565-
// Classify the throwing behavior of the "do" body.
3569+
// If the do..catch statement explicitly specifies that it throws, use
3570+
// that type.
3571+
if (Type explicitError = stmt->getExplicitlyThrownType(dc)) {
3572+
return explicitError;
3573+
}
3574+
3575+
// Otherwise, infer the thrown error type from the "do" body.
35663576
ApplyClassifier classifier(ctx);
35673577
Classification classification = classifier.classifyStmt(
35683578
stmt->getBody(), EffectKind::Throws);

lib/Sema/TypeCheckStmt.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1195,8 +1195,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
11951195
DC->getParentModule(), TS->getThrowLoc());
11961196
Type errorType;
11971197
if (catchNode) {
1198-
errorType = catchNode.getThrownErrorTypeInContext(getASTContext())
1199-
.value_or(Type());
1198+
errorType = catchNode.getThrownErrorTypeInContext(DC).value_or(Type());
12001199
}
12011200

12021201
// If there was no error type, use 'any Error'. We'll check it later.
@@ -1679,7 +1678,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
16791678
// Do-catch statements always limit exhaustivity checks.
16801679
bool limitExhaustivityChecks = true;
16811680

1682-
Type caughtErrorType = TypeChecker::catchErrorType(Ctx, S);
1681+
Type caughtErrorType = TypeChecker::catchErrorType(DC, S);
16831682
auto catches = S->getCatches();
16841683
checkSiblingCaseStmts(catches.begin(), catches.end(),
16851684
CaseParentKind::DoCatch, limitExhaustivityChecks,

0 commit comments

Comments
 (0)