Skip to content

Commit dfa262c

Browse files
SupunSturbolent
authored andcommitted
Refactor reference creation
1 parent e87a464 commit dfa262c

6 files changed

+103
-107
lines changed

interpreter/errors.go

+17
Original file line numberDiff line numberDiff line change
@@ -1067,6 +1067,23 @@ func (e NestedReferenceError) Error() string {
10671067
)
10681068
}
10691069

1070+
// NonOptionalReferenceToNilError
1071+
type NonOptionalReferenceToNilError struct {
1072+
ReferenceType sema.Type
1073+
LocationRange
1074+
}
1075+
1076+
var _ errors.UserError = NonOptionalReferenceToNilError{}
1077+
1078+
func (NonOptionalReferenceToNilError) IsUserError() {}
1079+
1080+
func (e NonOptionalReferenceToNilError) Error() string {
1081+
return fmt.Sprintf(
1082+
"cannot create a reference to nil: expected `%s`, but found `nil`",
1083+
e.ReferenceType.ID(),
1084+
)
1085+
}
1086+
10701087
// InclusiveRangeConstructionError
10711088

10721089
type InclusiveRangeConstructionError struct {

interpreter/interpreter_expression.go

+61-64
Original file line numberDiff line numberDiff line change
@@ -309,42 +309,7 @@ func (interpreter *Interpreter) memberExpressionGetterSetter(
309309
// e.g.1: Given type T, this method returns &T.
310310
// e.g.2: Given T?, this returns (&T)?
311311
func (interpreter *Interpreter) getReferenceValue(value Value, resultType sema.Type, locationRange LocationRange) Value {
312-
313-
// `resultType` is always an [optional] reference.
314-
// This is guaranteed by the checker.
315-
referenceType, ok := sema.UnwrapOptionalType(resultType).(*sema.ReferenceType)
316-
if !ok {
317-
panic(errors.NewUnreachableError())
318-
}
319-
320-
switch value := value.(type) {
321-
case NilValue, ReferenceValue:
322-
// Reference to a nil, should return a nil.
323-
// If the value is already a reference then return the same reference.
324-
// However, we need to make sure that this reference is actually a subtype of the resultType,
325-
// since the checker may not be aware that we are "short-circuiting" in this case
326-
// Additionally, it is only safe to "compress" reference types like this when the desired
327-
// result reference type is unauthorized
328-
329-
staticType := value.StaticType(interpreter)
330-
if referenceType.Authorization != sema.UnauthorizedAccess || !interpreter.IsSubTypeOfSemaType(staticType, resultType) {
331-
panic(InvalidMemberReferenceError{
332-
ExpectedType: resultType,
333-
ActualType: interpreter.MustConvertStaticToSemaType(staticType),
334-
LocationRange: locationRange,
335-
})
336-
}
337-
338-
return value
339-
340-
case *SomeValue:
341-
innerValue := interpreter.getReferenceValue(value.value, resultType, locationRange)
342-
return NewSomeValueNonCopying(interpreter, innerValue)
343-
}
344-
345-
auth := ConvertSemaAccessToStaticAuthorization(interpreter, referenceType.Authorization)
346-
347-
return NewEphemeralReferenceValue(interpreter, auth, value, referenceType.Type, locationRange)
312+
return interpreter.createReference(resultType, value, locationRange, true)
348313
}
349314

350315
func (interpreter *Interpreter) checkMemberAccess(
@@ -1466,20 +1431,30 @@ func (interpreter *Interpreter) VisitReferenceExpression(referenceExpression *as
14661431

14671432
result := interpreter.evalExpression(referenceExpression.Expression)
14681433

1469-
return interpreter.createReference(borrowType, result, referenceExpression)
1434+
locationRange := LocationRange{
1435+
Location: interpreter.Location,
1436+
HasPosition: referenceExpression,
1437+
}
1438+
1439+
return interpreter.createReference(borrowType, result, locationRange, false)
14701440
}
14711441

14721442
func (interpreter *Interpreter) createReference(
14731443
borrowType sema.Type,
14741444
value Value,
1475-
hasPosition ast.HasPosition,
1445+
locationRange LocationRange,
1446+
isImplicit bool,
14761447
) Value {
14771448

14781449
// There are four potential cases:
1479-
// 1) Target type is optional, actual value is also optional (nil/SomeValue)
1480-
// 2) Target type is optional, actual value is non-optional
1481-
// 3) Target type is non-optional, actual value is optional (SomeValue)
1482-
// 4) Target type is non-optional, actual value is non-optional
1450+
// (1) Target type is optional, actual value is also optional
1451+
// (1.a) value is SomeValue
1452+
// (1.b) value is nil
1453+
// (2) Target type is optional, actual value is non-optional
1454+
// (3) Target type is non-optional, actual value is optional
1455+
// (3.a) value is SomeValue
1456+
// (3.b) value is nil
1457+
// (4) Target type is non-optional, actual value is non-optional
14831458

14841459
switch typ := borrowType.(type) {
14851460
case *sema.OptionalType:
@@ -1489,48 +1464,75 @@ func (interpreter *Interpreter) createReference(
14891464
switch value := value.(type) {
14901465
case *SomeValue:
14911466
// Case (1):
1492-
// References to optionals are transformed into optional references,
1493-
// so move the *SomeValue out to the reference itself
1467+
// References to optionals are transformed into optional references.
14941468

1495-
locationRange := LocationRange{
1496-
Location: interpreter.Location,
1497-
HasPosition: hasPosition,
1498-
}
1469+
// (1.a) value is SomeValue
1470+
// Move the *SomeValue out to the reference itself.
14991471

15001472
innerValue := value.InnerValue(interpreter, locationRange)
15011473

1502-
referenceValue := interpreter.createReference(innerType, innerValue, hasPosition)
1474+
referenceValue := interpreter.createReference(innerType, innerValue, locationRange, false)
15031475

15041476
// Wrap the reference with an optional (since an optional is expected).
15051477
return NewSomeValueNonCopying(interpreter, referenceValue)
15061478

15071479
case NilValue:
1480+
// Case (1.b) value is nil.
1481+
// Since the resulting type can accommodate a nil (optional-reference), return nil,
15081482
return Nil
15091483

15101484
default:
15111485
// Case (2):
15121486
// If the referenced value is non-optional,
15131487
// but the target type is optional.
1514-
referenceValue := interpreter.createReference(innerType, value, hasPosition)
1488+
referenceValue := interpreter.createReference(innerType, value, locationRange, false)
15151489

15161490
// Wrap the reference with an optional (since an optional is expected).
15171491
return NewSomeValueNonCopying(interpreter, referenceValue)
15181492
}
15191493

15201494
case *sema.ReferenceType:
1521-
// Case (3): target type is non-optional, actual value is optional.
1522-
if someValue, ok := value.(*SomeValue); ok {
1523-
locationRange := LocationRange{
1524-
Location: interpreter.Location,
1525-
HasPosition: hasPosition,
1495+
1496+
switch value := value.(type) {
1497+
case *SomeValue:
1498+
// Case (3.a): target type is non-optional, actual value is optional.
1499+
innerValue := value.InnerValue(interpreter, locationRange)
1500+
1501+
return interpreter.createReference(typ, innerValue, locationRange, false)
1502+
1503+
case NilValue:
1504+
// Case (3.b) value is nil.
1505+
// Since the resulting type can NOT accommodate a nil (non-optional reference), error-out.
1506+
panic(NonOptionalReferenceToNilError{
1507+
ReferenceType: typ,
1508+
LocationRange: locationRange,
1509+
})
1510+
1511+
case ReferenceValue:
1512+
if isImplicit {
1513+
// During implicit reference creation (e.g: member/index access on a reference),
1514+
// if the value is already a reference then return the same reference.
1515+
// However, we need to make sure that this reference is actually a subtype of the resultType,
1516+
// since the checker may not be aware that we are "short-circuiting" in this case.
1517+
// Additionally, it is only safe to "compress" reference types like this when the desired
1518+
// result reference type is unauthorized
1519+
staticType := value.StaticType(interpreter)
1520+
if typ.Authorization != sema.UnauthorizedAccess || !interpreter.IsSubTypeOfSemaType(staticType, typ) {
1521+
panic(InvalidMemberReferenceError{
1522+
ExpectedType: typ,
1523+
ActualType: interpreter.MustConvertStaticToSemaType(staticType),
1524+
LocationRange: locationRange,
1525+
})
1526+
}
1527+
1528+
return value
15261529
}
1527-
innerValue := someValue.InnerValue(interpreter, locationRange)
15281530

1529-
return interpreter.createReference(typ, innerValue, hasPosition)
1531+
// break
15301532
}
15311533

15321534
// Case (4): target type is non-optional, actual value is also non-optional.
1533-
return interpreter.newEphemeralReference(value, typ, hasPosition)
1535+
return interpreter.newEphemeralReference(value, typ, locationRange)
15341536

15351537
default:
15361538
panic(errors.NewUnreachableError())
@@ -1540,17 +1542,12 @@ func (interpreter *Interpreter) createReference(
15401542
func (interpreter *Interpreter) newEphemeralReference(
15411543
value Value,
15421544
typ *sema.ReferenceType,
1543-
hasPosition ast.HasPosition,
1545+
locationRange LocationRange,
15441546
) *EphemeralReferenceValue {
15451547
// If we are currently interpreting a function that was declared with mapped entitlement access, any appearances
15461548
// of that mapped access in the body of the function should be replaced with the computed output of the map
15471549
auth := ConvertSemaAccessToStaticAuthorization(interpreter, typ.Authorization)
15481550

1549-
locationRange := LocationRange{
1550-
Location: interpreter.Location,
1551-
HasPosition: hasPosition,
1552-
}
1553-
15541551
return NewEphemeralReferenceValue(
15551552
interpreter,
15561553
auth,

interpreter/member_test.go

+3-10
Original file line numberDiff line numberDiff line change
@@ -696,16 +696,9 @@ func TestInterpretMemberAccess(t *testing.T) {
696696
}
697697
`)
698698

699-
// Currently a runtime error
700-
value, err := inter.Invoke("test")
701-
require.NoError(t, err)
702-
703-
AssertValuesEqual(
704-
t,
705-
inter,
706-
interpreter.Nil,
707-
value,
708-
)
699+
_, err := inter.Invoke("test")
700+
RequireError(t, err)
701+
require.ErrorAs(t, err, &interpreter.NonOptionalReferenceToNilError{})
709702
})
710703

711704
t.Run("composite reference, primitive field", func(t *testing.T) {

interpreter/reference_test.go

+5-15
Original file line numberDiff line numberDiff line change
@@ -1675,14 +1675,14 @@ func TestInterpretInvalidReferenceToOptionalConfusion(t *testing.T) {
16751675
_, err := inter.Invoke("main")
16761676
RequireError(t, err)
16771677

1678-
require.ErrorAs(t, err, &interpreter.ForceCastTypeMismatchError{})
1678+
require.ErrorAs(t, err, &interpreter.NonOptionalReferenceToNilError{})
16791679
}
16801680

16811681
func TestInterpretReferenceToOptional(t *testing.T) {
16821682

16831683
t.Parallel()
16841684

1685-
t.Run("nil in anystruct", func(t *testing.T) {
1685+
t.Run("nil in AnyStruct", func(t *testing.T) {
16861686
t.Parallel()
16871687

16881688
inter := parseCheckAndInterpret(t, `
@@ -1692,19 +1692,9 @@ func TestInterpretReferenceToOptional(t *testing.T) {
16921692
}
16931693
`)
16941694

1695-
value, err := inter.Invoke("main")
1696-
require.NoError(t, err)
1697-
1698-
AssertValuesEqual(
1699-
t,
1700-
inter,
1701-
&interpreter.EphemeralReferenceValue{
1702-
Value: interpreter.Nil,
1703-
BorrowedType: sema.AnyStructType,
1704-
Authorization: interpreter.UnauthorizedAccess,
1705-
},
1706-
value,
1707-
)
1695+
_, err := inter.Invoke("main")
1696+
RequireError(t, err)
1697+
require.ErrorAs(t, err, &interpreter.NonOptionalReferenceToNilError{})
17081698
})
17091699

17101700
t.Run("nil in optional", func(t *testing.T) {

sema/check_member_expression.go

+15-16
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import (
2525

2626
// NOTE: only called if the member expression is *not* an assignment
2727
func (checker *Checker) VisitMemberExpression(expression *ast.MemberExpression) Type {
28-
accessedType, memberType, member, isOptional := checker.visitMember(expression, false)
28+
accessedType, memberType, member, _ := checker.visitMember(expression, false)
2929

3030
if !accessedType.IsInvalidType() {
3131
memberAccessType := accessedType
@@ -83,14 +83,6 @@ func (checker *Checker) VisitMemberExpression(expression *ast.MemberExpression)
8383

8484
checker.checkResourceMemberCapturingInFunction(expression, member, memberType)
8585

86-
// If the member access is optional chaining, only wrap the result value
87-
// in an optional, if it is not already an optional value
88-
if isOptional {
89-
if _, ok := memberType.(*OptionalType); !ok {
90-
memberType = NewOptionalType(checker.memoryGauge, memberType)
91-
}
92-
}
93-
9486
return memberType
9587
}
9688

@@ -203,7 +195,7 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression, isAssignme
203195
// i.e. a Go type switch would be sufficient.
204196
// However, for some types (e.g. reference types) this depends on what type is referenced
205197

206-
getMemberForType := func(expressionType Type) {
198+
findAndSetResultingType := func(expressionType Type, optional bool) {
207199
resolver, ok := expressionType.GetMembers()[identifier]
208200
if !ok {
209201
return
@@ -216,6 +208,16 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression, isAssignme
216208
checker.report,
217209
)
218210
resultingType = member.TypeAnnotation.Type
211+
212+
// If the member is accessed using optional-chaining, then the resulting type also should be optional.
213+
// However, if the member is already optional, then no need to double-wrap from optionals.
214+
if optional {
215+
if _, memberIsOptional := resultingType.(*OptionalType); !memberIsOptional {
216+
resultingType = NewOptionalType(checker.memoryGauge, resultingType)
217+
}
218+
}
219+
220+
isOptional = optional
219221
}
220222

221223
// Get the member from the accessed value based
@@ -229,8 +231,7 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression, isAssignme
229231
if optionalExpressionType, ok := accessedType.(*OptionalType); ok {
230232
// The accessed type is optional, get the member from the wrapped type
231233

232-
getMemberForType(optionalExpressionType.Type)
233-
isOptional = true
234+
findAndSetResultingType(optionalExpressionType.Type, true)
234235
} else {
235236
// Optional chaining was used on a non-optional type, report an error
236237

@@ -247,14 +248,12 @@ func (checker *Checker) visitMember(expression *ast.MemberExpression, isAssignme
247248
// to avoid spurious error that member does not exist,
248249
// even if the non-optional accessed type has the member
249250

250-
getMemberForType(accessedType)
251+
findAndSetResultingType(accessedType, false)
251252
}
252253
} else {
253254
// The member is accessed directly without optional chaining.
254255
// Get the member directly from the accessed type
255-
256-
getMemberForType(accessedType)
257-
isOptional = false
256+
findAndSetResultingType(accessedType, false)
258257
}
259258

260259
if member == nil {

sema/check_reference_expression.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ func (checker *Checker) expectedTypeForReferencedExpr(
116116
checker.expectedTypeForReferencedExpr(expectedType.Type, hasPosition)
117117

118118
// Re-wrap with an optional
119-
expectedLeftType = &OptionalType{Type: expectedLeftType}
120-
returnType = &OptionalType{Type: returnType}
119+
expectedLeftType = NewOptionalType(checker.memoryGauge, expectedLeftType)
120+
returnType = NewOptionalType(checker.memoryGauge, returnType)
121121

122122
case *ReferenceType:
123123
referencedType := expectedType.Type

0 commit comments

Comments
 (0)