Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ extension TypeSyntaxProtocol {
.contains(.keyword(.some))
}

/// Whether or not this type is `any T` or a type derived from such a type.
var isAny: Bool {
tokens(viewMode: .fixedUp).lazy
.map(\.tokenKind)
.contains(.keyword(.any))
}

/// Check whether or not this type is named with the specified name and
/// module.
///
Expand Down
203 changes: 159 additions & 44 deletions Sources/TestingMacros/Support/ClosureCaptureListParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,62 +50,177 @@ struct CapturedValueInfo {
return
}

// Potentially get the name of the type comprising the current lexical
// context (i.e. whatever `Self` is.)
lazy var lexicalContext = context.lexicalContext
lazy var typeNameOfLexicalContext = {
let lexicalContext = lexicalContext.drop { !$0.isProtocol((any DeclGroupSyntax).self) }
return context.type(ofLexicalContext: lexicalContext)
}()
if let (expr, type) = Self._inferExpressionAndType(of: capture, in: context) {
self.expression = expr
self.type = type
} else {
// Not enough contextual information to derive the type here.
context.diagnose(.typeOfCaptureIsAmbiguous(capture))
}
}

/// Infer the captured expression and the type of a closure capture list item.
///
/// - Parameters:
/// - capture: The closure capture list item to inspect.
/// - context: The macro context in which the expression is being parsed.
///
/// - Returns: A tuple containing the expression and type of `capture`, or
/// `nil` if they could not be inferred.
private static func _inferExpressionAndType(of capture: ClosureCaptureSyntax, in context: some MacroExpansionContext) -> (ExprSyntax, TypeSyntax)? {
if let initializer = capture.initializer {
// Found an initializer clause. Extract the expression it captures.
self.expression = removeParentheses(from: initializer.value) ?? initializer.value
let finder = _ExprTypeFinder(in: context)
finder.walk(initializer.value)
if let inferredType = finder.inferredType {
return (initializer.value, inferredType)
}
} else if capture.name.tokenKind == .keyword(.self),
let typeNameOfLexicalContext = Self._inferSelf(from: context) {
// Capturing self.
return (ExprSyntax(DeclReferenceExprSyntax(baseName: .keyword(.self))), typeNameOfLexicalContext)
} else if let parameterType = Self._findTypeOfParameter(named: capture.name, in: context.lexicalContext) {
return (ExprSyntax(DeclReferenceExprSyntax(baseName: capture.name.trimmed)), parameterType)
}

return nil
}

private final class _ExprTypeFinder<C>: SyntaxAnyVisitor where C: MacroExpansionContext {
var context: C

/// The type that was inferred from the visited syntax tree, if any.
///
/// This type has not been fixed up yet. Use ``inferredType`` for the final
/// derived type.
private var _inferredType: TypeSyntax?

/// Whether or not the inferred type has been made optional by e.g. `try?`.
private var _needsOptionalApplied = false

/// The type that was inferred from the visited syntax tree, if any.
var inferredType: TypeSyntax? {
_inferredType.flatMap { inferredType in
if inferredType.isSome || inferredType.isAny {
// `some` and `any` types are not concrete and cannot be inferred.
nil
} else if _needsOptionalApplied {
TypeSyntax(OptionalTypeSyntax(wrappedType: inferredType.trimmed))
} else {
inferredType
}
}
}

init(in context: C) {
self.context = context
super.init(viewMode: .sourceAccurate)
}

override func visitAny(_ node: Syntax) -> SyntaxVisitorContinueKind {
if inferredType != nil {
// Another part of the syntax tree has already provided a type. Stop.
return .skipChildren
}

// Find the 'as' clause so we can determine the type of the captured value.
if let asExpr = self.expression.as(AsExprSyntax.self) {
self.type = if asExpr.questionOrExclamationMark?.tokenKind == .postfixQuestionMark {
switch node.kind {
case .asExpr:
let asExpr = node.cast(AsExprSyntax.self)
if let type = asExpr.type.as(IdentifierTypeSyntax.self), type.name.tokenKind == .keyword(.Self) {
// `Self` should resolve to the lexical context's type.
_inferredType = CapturedValueInfo._inferSelf(from: context)
} else if asExpr.questionOrExclamationMark?.tokenKind == .postfixQuestionMark {
// If the caller is using as?, make the type optional.
TypeSyntax(OptionalTypeSyntax(wrappedType: asExpr.type.trimmed))
_inferredType = TypeSyntax(OptionalTypeSyntax(wrappedType: asExpr.type.trimmed))
} else {
asExpr.type
_inferredType = asExpr.type
}
} else if let selfExpr = self.expression.as(DeclReferenceExprSyntax.self),
selfExpr.baseName.tokenKind == .keyword(.self),
selfExpr.argumentNames == nil,
let typeNameOfLexicalContext {
// Copying self.
self.type = typeNameOfLexicalContext
} else {
// Handle literals. Any other types are ambiguous.
switch self.expression.kind {
case .integerLiteralExpr:
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("IntegerLiteralType")))
case .floatLiteralExpr:
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("FloatLiteralType")))
case .booleanLiteralExpr:
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("BooleanLiteralType")))
case .stringLiteralExpr, .simpleStringLiteralExpr:
self.type = TypeSyntax(IdentifierTypeSyntax(name: .identifier("StringLiteralType")))
default:
context.diagnose(.typeOfCaptureIsAmbiguous(capture, initializedWith: initializer))
return .skipChildren

case .awaitExpr, .unsafeExpr:
// These effect keywords do not affect the type of the expression.
return .visitChildren

case .tryExpr:
let tryExpr = node.cast(TryExprSyntax.self)
if tryExpr.questionOrExclamationMark?.tokenKind == .postfixQuestionMark {
// The resulting type from the inner expression will be optionalized.
_needsOptionalApplied = true
}
}
return .visitChildren

} else if capture.name.tokenKind == .keyword(.self),
let typeNameOfLexicalContext {
// Capturing self.
self.expression = "self"
self.type = typeNameOfLexicalContext
} else if let parameterType = Self._findTypeOfParameter(named: capture.name, in: lexicalContext) {
self.expression = ExprSyntax(DeclReferenceExprSyntax(baseName: capture.name.trimmed))
self.type = parameterType
} else {
// Not enough contextual information to derive the type here.
context.diagnose(.typeOfCaptureIsAmbiguous(capture))
case .tupleExpr:
// If the tuple contains exactly one element, it's just parentheses
// around that expression.
let tupleExpr = node.cast(TupleExprSyntax.self)
if tupleExpr.elements.count == 1 {
return .visitChildren
}

// Otherwise, we need to try to compose the type as a tuple type from
// the types of all elements in the tuple expression. Note that tuples
// do not conform to Sendable or Codable, so our current use of this
// code in exit tests will still diagnose an error, but the error ("must
// conform") will be more useful than "couldn't infer".
let elements = tupleExpr.elements.compactMap { element in
let finder = Self(in: context)
finder.walk(element.expression)
return finder.inferredType.map { type in
TupleTypeElementSyntax(firstName: element.label?.trimmed, type: type.trimmed)
}
}
if elements.count == tupleExpr.elements.count {
_inferredType = TypeSyntax(
TupleTypeSyntax(elements: TupleTypeElementListSyntax { elements })
)
}
return .skipChildren

case .declReferenceExpr:
// If the reference is to `self` without any arguments, its type can be
// inferred from the lexical context.
let expr = node.cast(DeclReferenceExprSyntax.self)
if expr.baseName.tokenKind == .keyword(.self), expr.argumentNames == nil {
_inferredType = CapturedValueInfo._inferSelf(from: context)
}
return .skipChildren

case .integerLiteralExpr:
_inferredType = TypeSyntax(IdentifierTypeSyntax(name: .identifier("IntegerLiteralType")))
return .skipChildren

case .floatLiteralExpr:
_inferredType = TypeSyntax(IdentifierTypeSyntax(name: .identifier("FloatLiteralType")))
return .skipChildren

case .booleanLiteralExpr:
_inferredType = TypeSyntax(IdentifierTypeSyntax(name: .identifier("BooleanLiteralType")))
return .skipChildren

case .stringLiteralExpr, .simpleStringLiteralExpr:
_inferredType = TypeSyntax(IdentifierTypeSyntax(name: .identifier("StringLiteralType")))
return .skipChildren

default:
// We don't know how to infer a type from this syntax node, so do not
// proceed further.
return .skipChildren
}
}
}

/// Get the type of `self` inferred from the given context.
///
/// - Parameters:
/// - context: The macro context in which the expression is being parsed.
///
/// - Returns: The type in `lexicalContext` corresponding to `Self`, or `nil`
/// if it could not be determined.
private static func _inferSelf(from context: some MacroExpansionContext) -> TypeSyntax? {
let lexicalContext = context.lexicalContext.drop { !$0.isProtocol((any DeclGroupSyntax).self) }
return context.type(ofLexicalContext: lexicalContext)
}

/// Find a function or closure parameter in the given lexical context with a
/// given name and return its type.
///
Expand Down
4 changes: 4 additions & 0 deletions Tests/TestingMacrosTests/ConditionMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,10 @@ struct ConditionMacroTests {
"Type of captured value 'a' is ambiguous",
"#expectExitTest(processExitsWith: x) { [a = b] in }":
"Type of captured value 'a' is ambiguous",
"#expectExitTest(processExitsWith: x) { [a = b as any T] in }":
"Type of captured value 'a' is ambiguous",
"#expectExitTest(processExitsWith: x) { [a = b as some T] in }":
"Type of captured value 'a' is ambiguous",
"struct S<T> { func f() { #expectExitTest(processExitsWith: x) { [a] in } } }":
"Cannot call macro ''#expectExitTest(processExitsWith:_:)'' within generic structure 'S'",
]
Expand Down
38 changes: 37 additions & 1 deletion Tests/TestingTests/ExitTestTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,10 @@ private import _TestingInternals

@Test("self in capture list")
func captureListWithSelf() async {
await #expect(processExitsWith: .success) { [self, x = self] in
await #expect(processExitsWith: .success) { [self, x = self, y = self as Self] in
#expect(self.property == 456)
#expect(x.property == 456)
#expect(y.property == 456)
}
}
}
Expand Down Expand Up @@ -506,6 +507,41 @@ private import _TestingInternals
}
}

@Test("Capturing an optional value")
func captureListWithOptionalValue() async throws {
await #expect(processExitsWith: .success) { [x = nil as Int?] in
#expect(x != 1)
}
await #expect(processExitsWith: .success) { [x = (0 as Any) as? String] in
#expect(x == nil)
}
}

@Test("Capturing an effectful expression")
func captureListWithEffectfulExpression() async throws {
func f() async throws -> Int { 0 }
try await #require(processExitsWith: .success) { [f = try await f() as Int] in
#expect(f == 0)
}
try await #expect(processExitsWith: .success) { [f = f() as Int] in
#expect(f == 0)
}
}

#if false // intentionally fails to compile
@Test("Capturing a tuple")
func captureListWithTuple() async throws {
// A tuple whose elements conform to Codable does not itself conform to
// Codable, so we cannot actually express this capture list in a way that
// works with #expect().
await #expect(processExitsWith: .success) { [x = (0 as Int, 1 as Double, "2" as String)] in
#expect(x.0 == 0)
#expect(x.1 == 1)
#expect(x.2 == "2")
}
}
#endif

#if false // intentionally fails to compile
struct NonCodableValue {}

Expand Down