forked from OpenZeppelin/openzeppelin-contracts
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request OpenZeppelin#171 from Recmo/feature/reentrancy-guard
Add ReentrancyGuard
- Loading branch information
Showing
5 changed files
with
136 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
pragma solidity ^0.4.8; | ||
|
||
/// @title Helps contracts guard agains rentrancy attacks. | ||
/// @author Remco Bloemen <remco@2π.com> | ||
/// @notice If you mark a function `nonReentrant`, you should also | ||
/// mark it `external`. | ||
contract ReentrancyGuard { | ||
|
||
/// @dev We use a single lock for the whole contract. | ||
bool private rentrancy_lock = false; | ||
|
||
/// Prevent contract from calling itself, directly or indirectly. | ||
/// @notice If you mark a function `nonReentrant`, you should also | ||
/// mark it `external`. Calling one nonReentrant function from | ||
/// another is not supported. Instead, you can implement a | ||
/// `private` function doing the actual work, and a `external` | ||
/// wrapper marked as `nonReentrant`. | ||
modifier nonReentrant() { | ||
if(rentrancy_lock == false) { | ||
rentrancy_lock = true; | ||
_; | ||
rentrancy_lock = false; | ||
} else { | ||
throw; | ||
} | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
'use strict'; | ||
import expectThrow from './helpers/expectThrow'; | ||
const ReentrancyMock = artifacts.require('./helper/ReentrancyMock.sol'); | ||
const ReentrancyAttack = artifacts.require('./helper/ReentrancyAttack.sol'); | ||
|
||
contract('ReentrancyGuard', function(accounts) { | ||
let reentrancyMock; | ||
|
||
beforeEach(async function() { | ||
reentrancyMock = await ReentrancyMock.new(); | ||
let initialCounter = await reentrancyMock.counter(); | ||
assert.equal(initialCounter, 0); | ||
}); | ||
|
||
it('should not allow remote callback', async function() { | ||
let attacker = await ReentrancyAttack.new(); | ||
await expectThrow(reentrancyMock.countAndCall(attacker.address)); | ||
}); | ||
|
||
// The following are more side-effects that intended behaviour: | ||
// I put them here as documentation, and to monitor any changes | ||
// in the side-effects. | ||
|
||
it('should not allow local recursion', async function() { | ||
await expectThrow(reentrancyMock.countLocalRecursive(10)); | ||
}); | ||
|
||
it('should not allow indirect local recursion', async function() { | ||
await expectThrow(reentrancyMock.countThisRecursive(10)); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
pragma solidity ^0.4.8; | ||
|
||
contract ReentrancyAttack { | ||
|
||
function callSender(bytes4 data) { | ||
if(!msg.sender.call(data)) { | ||
throw; | ||
} | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
pragma solidity ^0.4.8; | ||
|
||
import '../../contracts/ReentrancyGuard.sol'; | ||
import './ReentrancyAttack.sol'; | ||
|
||
contract ReentrancyMock is ReentrancyGuard { | ||
|
||
uint256 public counter; | ||
|
||
function ReentrancyMock() { | ||
counter = 0; | ||
} | ||
|
||
function count() private { | ||
counter += 1; | ||
} | ||
|
||
function countLocalRecursive(uint n) public nonReentrant { | ||
if(n > 0) { | ||
count(); | ||
countLocalRecursive(n - 1); | ||
} | ||
} | ||
|
||
function countThisRecursive(uint256 n) public nonReentrant { | ||
bytes4 func = bytes4(keccak256("countThisRecursive(uint256)")); | ||
if(n > 0) { | ||
count(); | ||
bool result = this.call(func, n - 1); | ||
if(result != true) { | ||
throw; | ||
} | ||
} | ||
} | ||
|
||
function countAndCall(ReentrancyAttack attacker) public nonReentrant { | ||
count(); | ||
bytes4 func = bytes4(keccak256("callback()")); | ||
attacker.callSender(func); | ||
} | ||
|
||
function callback() external nonReentrant { | ||
count(); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
export default async promise => { | ||
try { | ||
await promise; | ||
} catch (error) { | ||
// TODO: Check jump destination to destinguish between a throw | ||
// and an actual invalid jump. | ||
const invalidJump = error.message.search('invalid JUMP') >= 0; | ||
// TODO: When we contract A calls contract B, and B throws, instead | ||
// of an 'invalid jump', we get an 'out of gas' error. How do | ||
// we distinguish this from an actual out of gas event? (The | ||
// testrpc log actually show an 'invalid jump' event.) | ||
const outOfGas = error.message.search('out of gas') >= 0; | ||
assert( | ||
invalidJump || outOfGas, | ||
"Expected throw, got '" + error + "' instead", | ||
); | ||
return; | ||
} | ||
assert.fail('Expected throw not received'); | ||
}; |