diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 0ee45e1d30..47ca8e7d64 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -20169,6 +20169,39 @@ func (v *DictionaryValue) Insert( return NilOptionalValue } + // At this point, existingValueStorable is not nil, which means previous op updated existing + // dictionary element (instead of inserting new element). + + // When existing dictionary element is updated, atree.OrderedMap reuses existing stored key + // so new key isn't stored or referenced in atree.OrderedMap. This aspect of atree cannot change + // without API changes in atree to return existing key storable for updated element. + + // Given this, remove the transferred key used to update existing dictionary element to + // prevent transferred key (in owner address) remaining in storage when it isn't + // referenced from dictionary. + + // Remove content of transferred keyValue. + keyValue.DeepRemove(interpreter, true) + + // Remove slab containing transferred keyValue from storage if needed. + // Currently, we only need to handle enum composite type because it is the only type that: + // - can be used as dictionary key (hashable) and + // - is transferred to its own slab. + if keyComposite, ok := keyValue.(*CompositeValue); ok { + + // Get SlabID of transferred enum value. + keyCompositeSlabID := keyComposite.SlabID() + + if keyCompositeSlabID == atree.SlabIDUndefined { + // It isn't possible for transferred enum value to be inlined in another container + // (SlabID as SlabIDUndefined) because it is transferred from stack by itself. + panic(errors.NewUnexpectedError("transferred enum value as dictionary key should not be inlined")) + } + + // Remove slab containing transferred enum value from storage. + interpreter.RemoveReferencedSlab(atree.SlabIDStorable(keyCompositeSlabID)) + } + storage := interpreter.Storage() existingValue := StoredValue( diff --git a/runtime/interpreter/value_test.go b/runtime/interpreter/value_test.go index 6a43cb3a6c..f0e26f1f11 100644 --- a/runtime/interpreter/value_test.go +++ b/runtime/interpreter/value_test.go @@ -4480,3 +4480,196 @@ func TestStringIsGraphemeBoundaryEnd(t *testing.T) { test(flagESflagEE, 16, true) } + +func TestOverwriteDictionaryValueWhereKeyIsStoredInSeparateAtreeSlab(t *testing.T) { + + t.Parallel() + + owner := common.Address{0x1} + + t.Run("enum as dict key", func(t *testing.T) { + + newEnumValue := func(inter *Interpreter) Value { + return NewCompositeValue( + inter, + EmptyLocationRange, + utils.TestLocation, + "Test", + common.CompositeKindEnum, + []CompositeField{ + { + Name: "rawValue", + Value: NewUnmeteredUInt8Value(42), + }, + }, + common.ZeroAddress, + ) + } + + storage := newUnmeteredInMemoryStorage() + + elaboration := sema.NewElaboration(nil) + elaboration.SetCompositeType( + testCompositeValueType.ID(), + testCompositeValueType, + ) + + inter, err := NewInterpreter( + &Program{ + Elaboration: elaboration, + }, + utils.TestLocation, + &Config{ + Storage: storage, + AtreeValueValidationEnabled: true, + AtreeStorageValidationEnabled: true, + }, + ) + require.NoError(t, err) + + // Create empty dictionary + dictionary := NewDictionaryValueWithAddress( + inter, + EmptyLocationRange, + &DictionaryStaticType{ + KeyType: PrimitiveStaticTypeAnyStruct, + ValueType: PrimitiveStaticTypeAnyStruct, + }, + owner, + ) + require.Equal(t, 0, dictionary.Count()) + + // Insert new key-value pair (enum as key) to dictionary + existingValue := dictionary.Insert( + inter, + EmptyLocationRange, + newEnumValue(inter), + NewUnmeteredInt64Value(int64(1)), + ) + require.Equal(t, NilOptionalValue, existingValue) + require.Equal(t, 1, dictionary.Count()) + + // Test inserted dictionary element + v, found := dictionary.Get( + inter, + EmptyLocationRange, + newEnumValue(inter), + ) + require.True(t, found) + require.Equal(t, Int64Value(1), v) + + // Update existing key with new value + existingValue = dictionary.Insert( + inter, + EmptyLocationRange, + newEnumValue(inter), + NewUnmeteredInt64Value(int64(2)), + ) + require.NotEqual(t, Int64Value(1), existingValue) + require.Equal(t, 1, dictionary.Count()) + + // Check updated dictionary element + v, found = dictionary.Get( + inter, + EmptyLocationRange, + newEnumValue(inter), + ) + require.True(t, found) + require.Equal(t, Int64Value(2), v) + + // Check storage containing only one root slab (dictionary root) + checkRootSlabIDsInStorage(t, storage, []atree.SlabID{dictionary.SlabID()}) + }) + + t.Run("large string as dict key", func(t *testing.T) { + newStringValue := func() Value { + return NewUnmeteredStringValue(strings.Repeat("a", 1024)) + } + + storage := newUnmeteredInMemoryStorage() + + elaboration := sema.NewElaboration(nil) + + inter, err := NewInterpreter( + &Program{ + Elaboration: elaboration, + }, + utils.TestLocation, + &Config{ + Storage: storage, + AtreeValueValidationEnabled: true, + AtreeStorageValidationEnabled: true, + }, + ) + require.NoError(t, err) + + // Create empty dictionary + dictionary := NewDictionaryValueWithAddress( + inter, + EmptyLocationRange, + &DictionaryStaticType{ + KeyType: PrimitiveStaticTypeAnyStruct, + ValueType: PrimitiveStaticTypeAnyStruct, + }, + owner, + ) + require.Equal(t, 0, dictionary.Count()) + + // Insert new key-value pair to dictionary + // Key is a large string which is stored in its own slab. + existingValue := dictionary.Insert( + inter, + EmptyLocationRange, + newStringValue(), + NewUnmeteredInt64Value(int64(1)), + ) + require.Equal(t, NilOptionalValue, existingValue) + require.Equal(t, 1, dictionary.Count()) + + // Check new dictionary element + v, found := dictionary.Get( + inter, + EmptyLocationRange, + newStringValue(), + ) + require.True(t, found) + require.Equal(t, Int64Value(1), v) + + // Update existing key with new value + existingValue = dictionary.Insert( + inter, + EmptyLocationRange, + newStringValue(), + NewUnmeteredInt64Value(int64(2)), + ) + require.NotEqual(t, Int64Value(1), existingValue) + require.Equal(t, 1, dictionary.Count()) + + // Check updated dictionary element + v, found = dictionary.Get( + inter, + EmptyLocationRange, + newStringValue(), + ) + require.True(t, found) + require.Equal(t, Int64Value(2), v) + + // Check storage containing only one root slab (dictionary root) + checkRootSlabIDsInStorage(t, storage, []atree.SlabID{dictionary.SlabID()}) + }) +} + +func checkRootSlabIDsInStorage(t *testing.T, storage atree.SlabStorage, expectedRootSlabIDs []atree.SlabID) { + rootSlabIDs, err := atree.CheckStorageHealth(storage, -1) + require.NoError(t, err) + + // Get non-temp address slab IDs from rootSlabIDs + nontempSlabIDs := make([]atree.SlabID, 0, len(rootSlabIDs)) + for rootSlabID := range rootSlabIDs { + if !rootSlabID.HasTempAddress() { + nontempSlabIDs = append(nontempSlabIDs, rootSlabID) + } + } + + require.ElementsMatch(t, expectedRootSlabIDs, nontempSlabIDs) +} diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 7f30aed2e9..7005274fed 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -11324,3 +11324,159 @@ func TestRuntimeForbidPublicEntitlementPublish(t *testing.T) { require.ErrorAs(t, err, &interpreter.EntitledCapabilityPublishingError{}) }) } + +func TestRuntimeStorageEnumAsDictionaryKey(t *testing.T) { + + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + address := common.MustBytesToAddress([]byte{0x1}) + + accountCodes := map[common.AddressLocation][]byte{} + var events []cadence.Event + var loggedMessages []string + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]common.Address, error) { + return []common.Address{address}, nil + }, + OnResolveLocation: NewSingleIdentifierLocationResolver(t), + OnUpdateAccountContractCode: func(location common.AddressLocation, code []byte) error { + accountCodes[location] = code + return nil + }, + OnGetAccountContractCode: func(location common.AddressLocation) (code []byte, err error) { + code = accountCodes[location] + return code, nil + }, + OnEmitEvent: func(event cadence.Event) error { + events = append(events, event) + return nil + }, + OnProgramLog: func(message string) { + loggedMessages = append(loggedMessages, message) + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + // Deploy contract + + err := runtime.ExecuteTransaction( + Script{ + Source: DeploymentTransaction( + "C", + []byte(` + access(all) contract C { + access(self) let counter: {E: UInt64} + access(all) enum E: UInt8 { + access(all) case A + access(all) case B + } + access(all) resource R { + access(all) let id: UInt64 + access(all) let e: E + init(id: UInt64, e: E) { + self.id = id + self.e = e + let counter = C.counter[e] ?? panic("couldn't retrieve resource counter") + // e is transferred and is stored in a slab which isn't removed after C.counter is updated. + C.counter[e] = counter + 1 + } + } + access(all) fun createR(id: UInt64, e: E): @R { + return <- create R(id: id, e: e) + } + access(all) resource Collection { + access(all) var rs: @{UInt64: R} + init () { + self.rs <- {} + } + access(all) fun withdraw(id: UInt64): @R { + return <- self.rs.remove(key: id)! + } + access(all) fun deposit(_ r: @R) { + let counts: {E: UInt64} = {} + log(r.e) + counts[r.e] = 42 // test indexing expression is transferred properly + log(r.e) + let oldR <- self.rs[r.id] <-! r + destroy oldR + } + } + access(all) fun createEmptyCollection(): @Collection { + return <- create Collection() + } + init() { + self.counter = { + E.A: 0, + E.B: 0 + } + } + } + `), + ), + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Store enum case + + err = runtime.ExecuteTransaction( + Script{ + Source: []byte(` + import C from 0x1 + transaction { + prepare(signer: auth(Storage) &Account) { + signer.storage.save(<-C.createEmptyCollection(), to: /storage/collection) + let collection = signer.storage.borrow<&C.Collection>(from: /storage/collection)! + collection.deposit(<-C.createR(id: 0, e: C.E.B)) + } + } + `), + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + // Load enum case + + err = runtime.ExecuteTransaction( + Script{ + Source: []byte(` + import C from 0x1 + transaction { + prepare(signer: auth(Storage) &Account) { + let collection = signer.storage.borrow<&C.Collection>(from: /storage/collection)! + let r <- collection.withdraw(id: 0) + log(r.e) + destroy r + } + } + `), + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + require.Equal(t, + []string{ + "A.0000000000000001.C.E(rawValue: 1)", + "A.0000000000000001.C.E(rawValue: 1)", + "A.0000000000000001.C.E(rawValue: 1)", + }, + loggedMessages, + ) +}