Skip to content

Commit

Permalink
Deduplicate logic in Votes.sol (OpenZeppelin#5314)
Browse files Browse the repository at this point in the history
Co-authored-by: Arr00 <13561405+arr00@users.noreply.github.com>
  • Loading branch information
Amxx and arr00 authored Nov 25, 2024
1 parent 2562c11 commit c3cb7a0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 23 deletions.
21 changes: 11 additions & 10 deletions contracts/governance/utils/Votes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ abstract contract Votes is Context, EIP712, Nonces, IERC5805 {
return "mode=blocknumber&from=default";
}

/**
* @dev Validate that a timepoint is in the past, and return it as a uint48.
*/
function _validateTimepoint(uint256 timepoint) internal view returns (uint48) {
uint48 currentTimepoint = clock();
if (timepoint >= currentTimepoint) revert ERC5805FutureLookup(timepoint, currentTimepoint);
return SafeCast.toUint48(timepoint);
}

/**
* @dev Returns the current amount of votes that `account` has.
*/
Expand All @@ -87,11 +96,7 @@ abstract contract Votes is Context, EIP712, Nonces, IERC5805 {
* - `timepoint` must be in the past. If operating using block numbers, the block must be already mined.
*/
function getPastVotes(address account, uint256 timepoint) public view virtual returns (uint256) {
uint48 currentTimepoint = clock();
if (timepoint >= currentTimepoint) {
revert ERC5805FutureLookup(timepoint, currentTimepoint);
}
return _delegateCheckpoints[account].upperLookupRecent(SafeCast.toUint48(timepoint));
return _delegateCheckpoints[account].upperLookupRecent(_validateTimepoint(timepoint));
}

/**
Expand All @@ -107,11 +112,7 @@ abstract contract Votes is Context, EIP712, Nonces, IERC5805 {
* - `timepoint` must be in the past. If operating using block numbers, the block must be already mined.
*/
function getPastTotalSupply(uint256 timepoint) public view virtual returns (uint256) {
uint48 currentTimepoint = clock();
if (timepoint >= currentTimepoint) {
revert ERC5805FutureLookup(timepoint, currentTimepoint);
}
return _totalCheckpoints.upperLookupRecent(SafeCast.toUint48(timepoint));
return _totalCheckpoints.upperLookupRecent(_validateTimepoint(timepoint));
}

/**
Expand Down
17 changes: 4 additions & 13 deletions contracts/governance/utils/VotesExtended.sol
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import {SafeCast} from "../../utils/math/SafeCast.sol";
* {ERC20Votes} and {ERC721Votes} follow this pattern and are thus safe to use with {VotesExtended}.
*/
abstract contract VotesExtended is Votes {
using SafeCast for uint256;
using Checkpoints for Checkpoints.Trace160;
using Checkpoints for Checkpoints.Trace208;

Expand All @@ -47,11 +46,7 @@ abstract contract VotesExtended is Votes {
* - `timepoint` must be in the past. If operating using block numbers, the block must be already mined.
*/
function getPastDelegate(address account, uint256 timepoint) public view virtual returns (address) {
uint48 currentTimepoint = clock();
if (timepoint >= currentTimepoint) {
revert ERC5805FutureLookup(timepoint, currentTimepoint);
}
return address(_delegateCheckpoints[account].upperLookupRecent(timepoint.toUint48()));
return address(_delegateCheckpoints[account].upperLookupRecent(_validateTimepoint(timepoint)));
}

/**
Expand All @@ -63,11 +58,7 @@ abstract contract VotesExtended is Votes {
* - `timepoint` must be in the past. If operating using block numbers, the block must be already mined.
*/
function getPastBalanceOf(address account, uint256 timepoint) public view virtual returns (uint256) {
uint48 currentTimepoint = clock();
if (timepoint >= currentTimepoint) {
revert ERC5805FutureLookup(timepoint, currentTimepoint);
}
return _balanceOfCheckpoints[account].upperLookupRecent(timepoint.toUint48());
return _balanceOfCheckpoints[account].upperLookupRecent(_validateTimepoint(timepoint));
}

/// @inheritdoc Votes
Expand All @@ -82,10 +73,10 @@ abstract contract VotesExtended is Votes {
super._transferVotingUnits(from, to, amount);
if (from != to) {
if (from != address(0)) {
_balanceOfCheckpoints[from].push(clock(), _getVotingUnits(from).toUint208());
_balanceOfCheckpoints[from].push(clock(), SafeCast.toUint208(_getVotingUnits(from)));
}
if (to != address(0)) {
_balanceOfCheckpoints[to].push(clock(), _getVotingUnits(to).toUint208());
_balanceOfCheckpoints[to].push(clock(), SafeCast.toUint208(_getVotingUnits(to)));
}
}
}
Expand Down

0 comments on commit c3cb7a0

Please sign in to comment.