Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API improvements for EnumerableSet #2151

Merged
merged 7 commits into from
Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
* `Address`: removed `toPayable`, use `payable(address)` instead. ([#2133](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2133))
* `ERC777`: `_send`, `_mint` and `_burn` now use the caller as the operator. ([#2134](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2134))
* `ERC777`: removed `_callsTokensToSend` and `_callTokensReceived`. ([#2134](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2134))
* `EnumerableSet`: renamed `get` to `at`. ([#2151](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2151))

## 2.5.0 (2020-02-04)

Expand Down
2 changes: 1 addition & 1 deletion contracts/access/AccessControl.sol
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ abstract contract AccessControl is Context {
* for more information.
*/
function getRoleMember(bytes32 role, uint256 index) public view returns (address) {
return _roles[role].members.get(index);
return _roles[role].members.at(index);
}

/**
Expand Down
4 changes: 2 additions & 2 deletions contracts/mocks/EnumerableSetMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ contract EnumerableSetMock{
return set.length();
}

function get(uint256 index) public view returns (address) {
nventuro marked this conversation as resolved.
Show resolved Hide resolved
return set.get(index);
function at(uint256 index) public view returns (address) {
return set.at(index);
}
}
94 changes: 50 additions & 44 deletions contracts/utils/EnumerableSet.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,58 +20,60 @@ pragma solidity ^0.6.0;
library EnumerableSet {

struct AddressSet {
// Position of the value in the `values` array, plus 1 because index 0
// means a value is not in the set.
mapping (address => uint256) index;
address[] values;
address[] keys;
nventuro marked this conversation as resolved.
Show resolved Hide resolved
// Position of the key in the `keys` array, plus 1 because index 0
// means a key is not in the set.
mapping (address => uint256) indexes;
nventuro marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* @dev Add a value to a set. O(1).
* Returns false if the value was already in the set.
* @dev Add a key to a set. O(1).
*
* Returns false if the key was already in the set.
*/
function add(AddressSet storage set, address value)
function add(AddressSet storage set, address key)
internal
returns (bool)
{
if (!contains(set, value)) {
set.values.push(value);
// The element is stored at length-1, but we add 1 to all indexes
if (!contains(set, key)) {
set.keys.push(key);
// The key is stored at length-1, but we add 1 to all indexes
// and use 0 as a sentinel value
set.index[value] = set.values.length;
set.indexes[key] = set.keys.length;
return true;
} else {
return false;
}
}

/**
* @dev Removes a value from a set. O(1).
* Returns false if the value was not present in the set.
* @dev Removes a key from a set. O(1).
*
* Returns false if the key was not present in the set.
*/
function remove(AddressSet storage set, address value)
function remove(AddressSet storage set, address key)
internal
returns (bool)
{
if (contains(set, value)){
uint256 toDeleteIndex = set.index[value] - 1;
uint256 lastIndex = set.values.length - 1;
if (contains(set, key)){
uint256 toDeleteIndex = set.indexes[key] - 1;
uint256 lastIndex = set.keys.length - 1;

// If the element we're deleting is the last one, we can just remove it without doing a swap
// If the key we're deleting is the last one, we can just remove it without doing a swap
if (lastIndex != toDeleteIndex) {
address lastValue = set.values[lastIndex];
address lastKey = set.keys[lastIndex];

// Move the last value to the index where the deleted value is
set.values[toDeleteIndex] = lastValue;
// Update the index for the moved value
set.index[lastValue] = toDeleteIndex + 1; // All indexes are 1-based
// Move the last key to the index where the deleted key is
set.keys[toDeleteIndex] = lastKey;
// Update the index for the moved key
set.indexes[lastKey] = toDeleteIndex + 1; // All indexes are 1-based
}

// Delete the index entry for the deleted value
delete set.index[value];
// Delete the slot where the moved key was stored
set.keys.pop();

// Delete the old entry for the moved value
set.values.pop();
// Delete the index for the deleted slot
delete set.indexes[key];

return true;
} else {
Expand All @@ -80,60 +82,64 @@ library EnumerableSet {
}

/**
* @dev Returns true if the value is in the set. O(1).
* @dev Returns true if the key is in the set. O(1).
*/
function contains(AddressSet storage set, address value)
function contains(AddressSet storage set, address key)
internal
view
returns (bool)
{
return set.index[value] != 0;
return set.indexes[key] != 0;
}

/**
* @dev Returns an array with all values in the set. O(N).
* Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed.
* @dev Returns an array with all keys in the set. O(N).
*
* Note that there are no guarantees on the ordering of keys inside the
* array, and it may change when more keys are added or removed.

* WARNING: This function may run out of gas on large sets: use {length} and
* {get} instead in these cases.
* {at} instead in these cases.
*/
function enumerate(AddressSet storage set)
internal
view
returns (address[] memory)
{
address[] memory output = new address[](set.values.length);
for (uint256 i; i < set.values.length; i++){
output[i] = set.values[i];
address[] memory output = new address[](set.keys.length);
for (uint256 i; i < set.keys.length; i++){
output[i] = set.keys[i];
}
return output;
}

/**
* @dev Returns the number of elements on the set. O(1).
* @dev Returns the number of keys on the set. O(1).
*/
function length(AddressSet storage set)
internal
view
returns (uint256)
{
return set.values.length;
return set.keys.length;
}

/** @dev Returns the element stored at position `index` in the set. O(1).
* Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed.
/**
* @dev Returns the key stored at position `index` in the set. O(1).
*
* Note that there are no guarantees on the ordering of keys inside the
* array, and it may change when more keys are added or removed.
*
* Requirements:
*
* - `index` must be strictly less than {length}.
*/
function get(AddressSet storage set, uint256 index)
function at(AddressSet storage set, uint256 index)
internal
view
returns (address)
{
return set.values[index];
require(set.keys.length > index, "EnumerableSet: index out of bounds");
return set.keys[index];
}
}
8 changes: 6 additions & 2 deletions test/utils/EnumerableSet.test.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const { accounts, contract } = require('@openzeppelin/test-environment');
const { expectEvent } = require('@openzeppelin/test-helpers');
const { expectEvent, expectRevert } = require('@openzeppelin/test-helpers');
const { expect } = require('chai');

const EnumerableSetMock = contract.fromArtifact('EnumerableSetMock');
Expand All @@ -21,7 +21,7 @@ describe('EnumerableSet', function () {
expect(await set.length()).to.bignumber.equal(members.length.toString());

expect(await Promise.all([...Array(members.length).keys()].map(index =>
set.get(index)
set.at(index)
))).to.have.same.members(members);
}

Expand Down Expand Up @@ -55,6 +55,10 @@ describe('EnumerableSet', function () {
await expectMembersMatch(this.set, [accountA]);
});

it('reverts when retrieving non-existent elements', async function () {
await expectRevert(this.set.at(0), 'EnumerableSet: index out of bounds');
});

it('removes added values', async function () {
await this.set.add(accountA);

Expand Down