Skip to content

[Typed throws] Implement support for do throws(...) syntax #70182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/swift/AST/CatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class CatchNode: public llvm::PointerUnion<
///
/// Returns the thrown error type for a throwing context, or \c llvm::None
/// if this is a non-throwing context.
llvm::Optional<Type> getThrownErrorTypeInContext(ASTContext &ctx) const;
llvm::Optional<Type> getThrownErrorTypeInContext(DeclContext *dc) const;
};

} // end namespace swift
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,9 @@ ERROR(expected_catch_where_expr,PointsToFirstBadToken,
ERROR(docatch_not_trycatch,PointsToFirstBadToken,
"the 'do' keyword is used to specify a 'catch' region",
())
ERROR(do_throws_without_catch,none,
"a 'do' statement with a 'throws' clause must have at least one 'catch'",
())

// C-Style For Stmt
ERROR(c_style_for_stmt_removed,none,
Expand Down
31 changes: 27 additions & 4 deletions include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "swift/AST/ConcreteDeclRef.h"
#include "swift/AST/IfConfigClause.h"
#include "swift/AST/TypeAlignments.h"
#include "swift/AST/TypeLoc.h"
#include "swift/AST/ThrownErrorDestination.h"
#include "swift/Basic/Debug.h"
#include "swift/Basic/NullablePtr.h"
Expand Down Expand Up @@ -1381,16 +1382,25 @@ class DoCatchStmt final
: public LabeledStmt,
private llvm::TrailingObjects<DoCatchStmt, CaseStmt *> {
friend TrailingObjects;
friend class DoCatchExplicitThrownTypeRequest;

SourceLoc DoLoc;

/// Location of the 'throws' token.
SourceLoc ThrowsLoc;

/// The error type that is being thrown.
TypeLoc ThrownType;

Stmt *Body;
ThrownErrorDestination RethrowDest;

DoCatchStmt(LabeledStmtInfo labelInfo, SourceLoc doLoc, Stmt *body,
DoCatchStmt(LabeledStmtInfo labelInfo, SourceLoc doLoc,
SourceLoc throwsLoc, TypeLoc thrownType, Stmt *body,
ArrayRef<CaseStmt *> catches, llvm::Optional<bool> implicit)
: LabeledStmt(StmtKind::DoCatch, getDefaultImplicitFlag(implicit, doLoc),
labelInfo),
DoLoc(doLoc), Body(body) {
DoLoc(doLoc), ThrowsLoc(throwsLoc), ThrownType(thrownType), Body(body) {
Bits.DoCatchStmt.NumCatches = catches.size();
std::uninitialized_copy(catches.begin(), catches.end(),
getTrailingObjects<CaseStmt *>());
Expand All @@ -1400,15 +1410,28 @@ class DoCatchStmt final

public:
static DoCatchStmt *create(ASTContext &ctx, LabeledStmtInfo labelInfo,
SourceLoc doLoc, Stmt *body,
SourceLoc doLoc,
SourceLoc throwsLoc, TypeLoc thrownType,
Stmt *body,
ArrayRef<CaseStmt *> catches,
llvm::Optional<bool> implicit = llvm::None);

SourceLoc getDoLoc() const { return DoLoc; }

/// Retrieve the location of the 'throws' keyword, if present.
SourceLoc getThrowsLoc() const { return ThrowsLoc; }

SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(DoLoc); }
SourceLoc getEndLoc() const { return getCatches().back()->getEndLoc(); }

/// Retrieves the type representation for the thrown type.
TypeRepr *getThrownTypeRepr() const {
return ThrownType.getTypeRepr();
}

// Get the explicitly-specified thrown error type.
Type getExplicitlyThrownType(DeclContext *dc) const;

Stmt *getBody() const { return Body; }
void setBody(Stmt *s) { Body = s; }

Expand All @@ -1433,7 +1456,7 @@ class DoCatchStmt final
// and caught by the various 'catch' clauses. If this the catch clauses
// aren't exhausive, this is also the type of the error that is implicitly
// rethrown.
Type getCaughtErrorType() const;
Type getCaughtErrorType(DeclContext *dc) const;

/// Retrieves the rethrown error and its conversion to the error type
/// expected by the enclosing context.
Expand Down
22 changes: 22 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ContextualPattern;
class ContinueStmt;
class DefaultArgumentExpr;
class DefaultArgumentType;
class DoCatchStmt;
struct ExternalMacroDefinition;
class ClosureExpr;
class GenericParamList;
Expand Down Expand Up @@ -2303,6 +2304,27 @@ class ThrownTypeRequest
void cacheResult(Type value) const;
};

/// Determines the explicitly-written thrown error type in a do..catch block.
class DoCatchExplicitThrownTypeRequest
: public SimpleRequest<DoCatchExplicitThrownTypeRequest,
Type(DeclContext *, DoCatchStmt *),
RequestFlags::SeparatelyCached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
Type evaluate(Evaluator &evaluator, DeclContext *dc, DoCatchStmt *stmt) const;

public:
// Separate caching.
bool isCached() const;
llvm::Optional<Type> getCachedResult() const;
void cacheResult(Type value) const;
};

/// Determines the result type of a function or element type of a subscript.
class ResultTypeRequest
: public SimpleRequest<ResultTypeRequest,
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ SWIFT_REQUEST(TypeChecker, ParamSpecifierRequest,
ParamDecl::Specifier(ParamDecl *), SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, ThrownTypeRequest,
Type(AbstractFunctionDecl *), SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DoCatchExplicitThrownTypeRequest,
Type(DeclContext *, DoCatchStmt *), SeparatelyCached,
NoLocationInfo)
SWIFT_REQUEST(TypeChecker, ResultTypeRequest,
Type(ValueDecl *), SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, AreAllStoredPropertiesDefaultInitableRequest,
Expand Down
14 changes: 12 additions & 2 deletions lib/AST/ASTVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1020,13 +1020,23 @@ class Verifier : public ASTWalker {
return shouldVerifyChecked(S->getSubExpr());
}

DeclContext *getInnermostDC() const {
for (auto scope : llvm::reverse(Scopes)) {
if (auto dc = scope.dyn_cast<DeclContext *>())
return dc;
}

return nullptr;
}

void verifyChecked(ThrowStmt *S) {
Type thrownError;
SourceLoc loc = S->getThrowLoc();
if (loc.isValid()) {
auto catchNode = ASTScope::lookupCatchNode(getModuleContext(), loc);
if (catchNode) {
if (auto thrown = catchNode.getThrownErrorTypeInContext(Ctx)) {
DeclContext *dc = getInnermostDC();
if (catchNode && dc) {
if (auto thrown = catchNode.getThrownErrorTypeInContext(dc)) {
thrownError = *thrown;
} else {
thrownError = Ctx.getNeverType();
Expand Down
6 changes: 3 additions & 3 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11691,7 +11691,7 @@ MacroDiscriminatorContext::getParentOf(FreestandingMacroExpansion *expansion) {
}

llvm::Optional<Type>
CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
CatchNode::getThrownErrorTypeInContext(DeclContext *dc) const {
if (auto func = dyn_cast<AbstractFunctionDecl *>()) {
if (auto thrownError = func->getEffectiveThrownErrorType())
return func->mapTypeIntoContext(*thrownError);
Expand All @@ -11708,13 +11708,13 @@ CatchNode::getThrownErrorTypeInContext(ASTContext &ctx) const {
}

auto doCatch = get<DoCatchStmt *>();
if (auto thrownError = doCatch->getCaughtErrorType()) {
if (auto thrownError = doCatch->getCaughtErrorType(dc)) {
if (thrownError->isNever())
return llvm::None;

return thrownError;
}

// If we haven't computed the error type yet, do so now.
return ctx.getErrorExistentialType();
return dc->getASTContext().getErrorExistentialType();
}
19 changes: 16 additions & 3 deletions lib/AST/Stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,12 +450,15 @@ Expr *ForEachStmt::getTypeCheckedSequence() const {
}

DoCatchStmt *DoCatchStmt::create(ASTContext &ctx, LabeledStmtInfo labelInfo,
SourceLoc doLoc, Stmt *body,
SourceLoc doLoc,
SourceLoc throwsLoc, TypeLoc thrownType,
Stmt *body,
ArrayRef<CaseStmt *> catches,
llvm::Optional<bool> implicit) {
void *mem = ctx.Allocate(totalSizeToAlloc<CaseStmt *>(catches.size()),
alignof(DoCatchStmt));
return ::new (mem) DoCatchStmt(labelInfo, doLoc, body, catches, implicit);
return ::new (mem) DoCatchStmt(labelInfo, doLoc, throwsLoc, thrownType, body,
catches, implicit);
}

bool CaseLabelItem::isSyntacticallyExhaustive() const {
Expand All @@ -472,7 +475,17 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const {
return false;
}

Type DoCatchStmt::getCaughtErrorType() const {
Type DoCatchStmt::getExplicitlyThrownType(DeclContext *dc) const {
ASTContext &ctx = dc->getASTContext();
DoCatchExplicitThrownTypeRequest request{dc, const_cast<DoCatchStmt *>(this)};
return evaluateOrDefault(ctx.evaluator, request, Type());
}

Type DoCatchStmt::getCaughtErrorType(DeclContext *dc) const {
// Check for an explicitly-specified error type.
if (Type explicitError = getExplicitlyThrownType(dc))
return explicitError;

auto firstPattern = getCatches()
.front()
->getCaseLabelItems()
Expand Down
23 changes: 23 additions & 0 deletions lib/AST/TypeCheckRequests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,29 @@ void ThrownTypeRequest::cacheResult(Type type) const {
func->ThrownType.setType(type);
}

//----------------------------------------------------------------------------//
// DoCatchExplicitThrownTypeRequest computation.
//----------------------------------------------------------------------------//

bool DoCatchExplicitThrownTypeRequest::isCached() const {
auto *const stmt = std::get<1>(getStorage());
return stmt->getThrowsLoc().isValid();
}

llvm::Optional<Type> DoCatchExplicitThrownTypeRequest::getCachedResult() const {
auto *const stmt = std::get<1>(getStorage());
Type thrownType = stmt->ThrownType.getType();
if (thrownType.isNull())
return llvm::None;

return thrownType;
}

void DoCatchExplicitThrownTypeRequest::cacheResult(Type type) const {
auto *const stmt = std::get<1>(getStorage());
stmt->ThrownType.setType(type);
}

//----------------------------------------------------------------------------//
// ResultTypeRequest computation.
//----------------------------------------------------------------------------//
Expand Down
30 changes: 27 additions & 3 deletions lib/Parse/ParseStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2191,8 +2191,8 @@ ParserResult<Stmt> Parser::parseStmtRepeat(LabeledStmtInfo labelInfo) {

///
/// stmt-do:
/// (identifier ':')? 'do' stmt-brace
/// (identifier ':')? 'do' stmt-brace stmt-catch+
/// (identifier ':')? 'do' throws-clause? stmt-brace
/// (identifier ':')? 'do' throws-clause? stmt-brace stmt-catch+
ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
bool shouldSkipDoTokenConsume) {
SourceLoc doLoc;
Expand All @@ -2205,6 +2205,25 @@ ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,

ParserStatus status;

// Parse the optional 'throws' clause.
SourceLoc throwsLoc;
TypeRepr *thrownType = nullptr;
if (consumeIf(tok::kw_throws, throwsLoc)) {
// Parse the thrown error type.
SourceLoc lParenLoc;
if (consumeIf(tok::l_paren, lParenLoc)) {
ParserResult<TypeRepr> parsedThrownTy =
parseType(diag::expected_thrown_error_type);
thrownType = parsedThrownTy.getPtrOrNull();
status |= parsedThrownTy;

SourceLoc rParenLoc;
parseMatchingToken(
tok::r_paren, rParenLoc,
diag::expected_rparen_after_thrown_error_type, lParenLoc);
}
}

ParserResult<BraceStmt> body =
parseBraceItemList(diag::expected_lbrace_after_do);
status |= body;
Expand Down Expand Up @@ -2236,7 +2255,12 @@ ParserResult<Stmt> Parser::parseStmtDo(LabeledStmtInfo labelInfo,
}

return makeParserResult(status,
DoCatchStmt::create(Context, labelInfo, doLoc, body.get(), allClauses));
DoCatchStmt::create(Context, labelInfo, doLoc, throwsLoc, thrownType,
body.get(), allClauses));
}

if (throwsLoc.isValid()) {
diagnose(throwsLoc, diag::do_throws_without_catch);
}

// If we dont see a 'while' or see a 'while' that starts
Expand Down
2 changes: 1 addition & 1 deletion lib/SILGen/SILGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ void StmtEmitter::visitDoStmt(DoStmt *S) {
}

void StmtEmitter::visitDoCatchStmt(DoCatchStmt *S) {
Type formalExnType = S->getCaughtErrorType();
Type formalExnType = S->getCaughtErrorType(SGF.FunctionDC);
auto &exnTL = SGF.getTypeLowering(formalExnType);

SILValue exnArg;
Expand Down
20 changes: 15 additions & 5 deletions lib/Sema/TypeCheckEffects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2788,9 +2788,10 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
/// Retrieve the type of the error that can be caught when an error is
/// thrown from the given location.
Type getCaughtErrorTypeAt(SourceLoc loc) {
auto module = CurContext.getDeclContext()->getParentModule();
auto dc = CurContext.getDeclContext();
auto module = dc->getParentModule();
if (CatchNode catchNode = ASTScope::lookupCatchNode(module, loc)) {
if (auto caughtType = catchNode.getThrownErrorTypeInContext(Ctx))
if (auto caughtType = catchNode.getThrownErrorTypeInContext(dc))
return *caughtType;
}

Expand Down Expand Up @@ -2917,7 +2918,8 @@ class CheckEffectsCoverage : public EffectsHandlingWalker<CheckEffectsCoverage>
// specialized diagnostic about non-exhaustive catches.
if (!CurContext.handlesThrows(ConditionalEffectKind::Conditional)) {
CurContext.setNonExhaustiveCatch(true);
} else if (Type rethrownErrorType = S->getCaughtErrorType()) {
} else if (Type rethrownErrorType =
S->getCaughtErrorType(CurContext.getDeclContext())) {
// We're implicitly rethrowing the error out of this do..catch, so make
// sure that we can throw an error of this type out of this context.
auto catches = S->getCatches();
Expand Down Expand Up @@ -3554,15 +3556,23 @@ llvm::Optional<Type> TypeChecker::canThrow(ASTContext &ctx, Expr *expr) {
return classification.getThrownError();
}

Type TypeChecker::catchErrorType(ASTContext &ctx, DoCatchStmt *stmt) {
Type TypeChecker::catchErrorType(DeclContext *dc, DoCatchStmt *stmt) {
ASTContext &ctx = dc->getASTContext();

// When typed throws is disabled, this is always "any Error".
// FIXME: When we distinguish "precise" typed throws from normal typed
// throws, we'll be able to compute a more narrow catch error type in some
// case, e.g., from a `try` but not a `throws`.
if (!ctx.LangOpts.hasFeature(Feature::TypedThrows))
return ctx.getErrorExistentialType();

// Classify the throwing behavior of the "do" body.
// If the do..catch statement explicitly specifies that it throws, use
// that type.
if (Type explicitError = stmt->getExplicitlyThrownType(dc)) {
return explicitError;
}

// Otherwise, infer the thrown error type from the "do" body.
ApplyClassifier classifier(ctx);
Classification classification = classifier.classifyStmt(
stmt->getBody(), EffectKind::Throws);
Expand Down
5 changes: 2 additions & 3 deletions lib/Sema/TypeCheckStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1195,8 +1195,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
DC->getParentModule(), TS->getThrowLoc());
Type errorType;
if (catchNode) {
errorType = catchNode.getThrownErrorTypeInContext(getASTContext())
.value_or(Type());
errorType = catchNode.getThrownErrorTypeInContext(DC).value_or(Type());
}

// If there was no error type, use 'any Error'. We'll check it later.
Expand Down Expand Up @@ -1679,7 +1678,7 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
// Do-catch statements always limit exhaustivity checks.
bool limitExhaustivityChecks = true;

Type caughtErrorType = TypeChecker::catchErrorType(Ctx, S);
Type caughtErrorType = TypeChecker::catchErrorType(DC, S);
auto catches = S->getCatches();
checkSiblingCaseStmts(catches.begin(), catches.end(),
CaseParentKind::DoCatch, limitExhaustivityChecks,
Expand Down
Loading