Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API improvements for EnumerableSet #2151

Merged
merged 7 commits into from
Mar 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
* `Address`: removed `toPayable`, use `payable(address)` instead. ([#2133](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2133))
* `ERC777`: `_send`, `_mint` and `_burn` now use the caller as the operator. ([#2134](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2134))
* `ERC777`: removed `_callsTokensToSend` and `_callTokensReceived`. ([#2134](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2134))
* `EnumerableSet`: renamed `get` to `at`. ([#2151](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2151))
* `ERC165Checker`: functions no longer have a leading underscore. ([#2150](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2150))

## 2.5.0 (2020-02-04)
Expand Down
2 changes: 1 addition & 1 deletion contracts/access/AccessControl.sol
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ abstract contract AccessControl is Context {
* for more information.
*/
function getRoleMember(bytes32 role, uint256 index) public view returns (address) {
return _roles[role].members.get(index);
return _roles[role].members.at(index);
}

/**
Expand Down
6 changes: 3 additions & 3 deletions contracts/mocks/EnumerableSetMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pragma solidity ^0.6.0;

import "../utils/EnumerableSet.sol";

contract EnumerableSetMock{
contract EnumerableSetMock {
using EnumerableSet for EnumerableSet.AddressSet;

event TransactionResult(bool result);
Expand Down Expand Up @@ -31,7 +31,7 @@ contract EnumerableSetMock{
return _set.length();
}

function get(uint256 index) public view returns (address) {
nventuro marked this conversation as resolved.
Show resolved Hide resolved
return _set.get(index);
function at(uint256 index) public view returns (address) {
return _set.at(index);
}
}
56 changes: 31 additions & 25 deletions contracts/utils/EnumerableSet.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,26 @@ pragma solidity ^0.6.0;
library EnumerableSet {

struct AddressSet {
address[] _values;
// Position of the value in the `values` array, plus 1 because index 0
// means a value is not in the set.
mapping (address => uint256) index;
address[] values;
mapping (address => uint256) _indexes;
}

/**
* @dev Add a value to a set. O(1).
*
* Returns false if the value was already in the set.
*/
function add(AddressSet storage set, address value)
internal
returns (bool)
{
if (!contains(set, value)) {
set.values.push(value);
// The element is stored at length-1, but we add 1 to all indexes
set._values.push(value);
// The value is stored at length-1, but we add 1 to all indexes
// and use 0 as a sentinel value
set.index[value] = set.values.length;
set._indexes[value] = set._values.length;
return true;
} else {
return false;
Expand All @@ -47,31 +48,32 @@ library EnumerableSet {

/**
* @dev Removes a value from a set. O(1).
*
* Returns false if the value was not present in the set.
*/
function remove(AddressSet storage set, address value)
internal
returns (bool)
{
if (contains(set, value)){
uint256 toDeleteIndex = set.index[value] - 1;
uint256 lastIndex = set.values.length - 1;
uint256 toDeleteIndex = set._indexes[value] - 1;
uint256 lastIndex = set._values.length - 1;

// If the element we're deleting is the last one, we can just remove it without doing a swap
// If the value we're deleting is the last one, we can just remove it without doing a swap
if (lastIndex != toDeleteIndex) {
address lastValue = set.values[lastIndex];
address lastvalue = set._values[lastIndex];

// Move the last value to the index where the deleted value is
set.values[toDeleteIndex] = lastValue;
set._values[toDeleteIndex] = lastvalue;
// Update the index for the moved value
set.index[lastValue] = toDeleteIndex + 1; // All indexes are 1-based
set._indexes[lastvalue] = toDeleteIndex + 1; // All indexes are 1-based
}

// Delete the index entry for the deleted value
delete set.index[value];
// Delete the slot where the moved value was stored
set._values.pop();

// Delete the old entry for the moved value
set.values.pop();
// Delete the index for the deleted slot
delete set._indexes[value];

return true;
} else {
Expand All @@ -87,53 +89,57 @@ library EnumerableSet {
view
returns (bool)
{
return set.index[value] != 0;
return set._indexes[value] != 0;
}

/**
* @dev Returns an array with all values in the set. O(N).
*
* Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed.

* WARNING: This function may run out of gas on large sets: use {length} and
* {get} instead in these cases.
* {at} instead in these cases.
*/
function enumerate(AddressSet storage set)
internal
view
returns (address[] memory)
{
address[] memory output = new address[](set.values.length);
for (uint256 i; i < set.values.length; i++){
output[i] = set.values[i];
address[] memory output = new address[](set._values.length);
for (uint256 i; i < set._values.length; i++){
output[i] = set._values[i];
}
return output;
}

/**
* @dev Returns the number of elements on the set. O(1).
* @dev Returns the number of values on the set. O(1).
*/
function length(AddressSet storage set)
internal
view
returns (uint256)
{
return set.values.length;
return set._values.length;
}

/** @dev Returns the element stored at position `index` in the set. O(1).
/**
* @dev Returns the value stored at position `index` in the set. O(1).
*
* Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed.
*
* Requirements:
*
* - `index` must be strictly less than {length}.
*/
function get(AddressSet storage set, uint256 index)
function at(AddressSet storage set, uint256 index)
internal
view
returns (address)
{
return set.values[index];
require(set._values.length > index, "EnumerableSet: index out of bounds");
return set._values[index];
}
}
8 changes: 6 additions & 2 deletions test/utils/EnumerableSet.test.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const { accounts, contract } = require('@openzeppelin/test-environment');
const { expectEvent } = require('@openzeppelin/test-helpers');
const { expectEvent, expectRevert } = require('@openzeppelin/test-helpers');
const { expect } = require('chai');

const EnumerableSetMock = contract.fromArtifact('EnumerableSetMock');
Expand All @@ -21,7 +21,7 @@ describe('EnumerableSet', function () {
expect(await set.length()).to.bignumber.equal(members.length.toString());

expect(await Promise.all([...Array(members.length).keys()].map(index =>
set.get(index)
set.at(index)
))).to.have.same.members(members);
}

Expand Down Expand Up @@ -55,6 +55,10 @@ describe('EnumerableSet', function () {
await expectMembersMatch(this.set, [accountA]);
});

it('reverts when retrieving non-existent elements', async function () {
await expectRevert(this.set.at(0), 'EnumerableSet: index out of bounds');
});

it('removes added values', async function () {
await this.set.add(accountA);

Expand Down