diff --git a/.gitignore b/.gitignore index 7a7b33a..94108b3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ package-lock.json node_modules +artifacts +cache diff --git a/contracts/Controller.sol b/contracts/Controller.sol index 1181e6b..6bf1c70 100644 --- a/contracts/Controller.sol +++ b/contracts/Controller.sol @@ -1,12 +1,14 @@ // SPDX-License-Identifier: MIT pragma solidity 0.8.14; + import "@openzeppelin/contracts/token/ERC721/extensions/IERC721Enumerable.sol"; import "@openzeppelin/contracts/token/ERC721/IERC721Receiver.sol"; import "@openzeppelin/contracts/token/ERC721/IERC721.sol"; import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol"; +import "./MyNFT.sol"; import "./MyToken.sol"; /** @@ -14,9 +16,11 @@ import "./MyToken.sol"; */ contract Controller is IERC721Receiver { - address immutable internal deployer; - IERC721Enumerable private immutable stakedNFT; + address internal immutable deployer; + MyNFT private immutable stakedNFT; mapping(uint256 => address) public staker; // keep track of who owns what NFT to handle withdrawals + mapping(address => mapping(uint256 => uint256)) private globalIndexToTokenId; // (staker) -> (global list index) -> (token ID) + mapping(address => mapping(uint256 => uint256)) private tokenIdToGlobalIndex; // (staker) -> (token ID) -> (global list index) // track rewards MyToken private rewardToken; // the token you earn as reward @@ -32,10 +36,10 @@ contract Controller is IERC721Receiver { * @notice Construct the controller. * @param stakedNFTContractAddress The NFT collection address being staked for rewards. */ - constructor(IERC721Enumerable stakedNFTContractAddress, MyToken _rewardToken) { + constructor(MyNFT stakedNFTContractAddress, MyToken _rewardTokenAddress) { deployer = msg.sender; - stakedNFT = stakedNFTContractAddress; - rewardToken = _rewardToken; + stakedNFT = MyNFT(stakedNFTContractAddress); + rewardToken = MyToken(_rewardTokenAddress); } /** @@ -65,6 +69,10 @@ contract Controller is IERC721Receiver { numStaked[_staker] -= 1; delete stakedAtBlocktimestamp[tokenId]; delete numIntervalsCollected[tokenId]; + + uint index = tokenIdToGlobalIndex[_staker][tokenId]; + delete tokenIdToGlobalIndex[_staker][index]; + delete globalIndexToTokenId[_staker][tokenId]; } /** @@ -80,30 +88,16 @@ contract Controller is IERC721Receiver { stakedNFT.safeTransferFrom(address(this), msg.sender, tokenId); } - // TODO: THE CONTROLLER NEEDS TO KNOW ABOUT THE TOKEN ADDRESS AND OWN SOME AS WELL TO SEND AS REWARDS!!! - // TODO: SO MAKE SURE THAT THE CONTROLLER CAN MINT WHATEVER THEY WANT... JUST CALL MINT WITH RECEIVER ADDRESS ACTUALLY - /** * @notice Withdraw the rewards for this user given all their staked tokens. */ function withdrawReward() external { require(numStaked[msg.sender] > 0, "not staker in the contract!"); + uint _numStaked = numStaked[msg.sender]; - // how many NFTs are owned by the caller in the collection - uint bal = stakedNFT.balanceOf(msg.sender); - assert(bal > 0); - - uint numFound; uint totalNumIntervals; // across all staked NFTs of this user - for (uint i; i < bal; i++) { - // get this specific NFT owned by the caller: - uint thisTokenID = stakedNFT.tokenOfOwnerByIndex(msg.sender, i); - - // is the owner/caller staking this NFT in our contract rn? - if (staker[thisTokenID] == address(0)) - continue; // ..no - - numFound += 1; + for (uint i; i < _numStaked; i++) { + uint thisTokenID = globalIndexToTokenId[msg.sender][i]; // yes, so collect rewards, count the number of `rewardInterval` for this NFT uint256 thisNumIntervals = _compute_num_24hours_single_nft(thisTokenID); @@ -116,9 +110,6 @@ contract Controller is IERC721Receiver { // update the rewards collected for this token numIntervalsCollected[thisTokenID] += thisNumIntervals; - - if (numFound == numStaked[msg.sender]) - break; } uint totalReward = totalNumIntervals * rewardRate_per_Interval; @@ -137,9 +128,9 @@ contract Controller is IERC721Receiver { uint deltaTime = currTime - stakedAtBlocktimestamp[tokenId]; if (deltaTime <= 0) return 0; - uint num24hours = deltaTime % rewardInterval; // rewardInterval = 24 hours, e.g. + uint numIntervals = deltaTime / rewardInterval; // rewardInterval = 24 hours, e.g. - return num24hours; + return numIntervals; } /** @@ -154,9 +145,14 @@ contract Controller is IERC721Receiver { bytes calldata ) external returns (bytes4) { + // keep track of who owns which nfts staker[tokenId] = from; // log the owner of the NFT - stakedAtBlocktimestamp[tokenId] = block.timestamp; + globalIndexToTokenId[from][numStaked[from]] = tokenId; + tokenIdToGlobalIndex[from][tokenId] = numStaked[from]; // where is the tokenId in the global index, for this staker numStaked[from] += 1; + + // for rewards + stakedAtBlocktimestamp[tokenId] = block.timestamp; numIntervalsCollected[tokenId] = 0; return IERC721Receiver.onERC721Received.selector; diff --git a/package.json b/package.json index e07ab8c..f961eac 100644 --- a/package.json +++ b/package.json @@ -19,6 +19,6 @@ "devDependencies": { "@nomicfoundation/hardhat-toolbox": "^1.0.2", "@openzeppelin/contracts": "^4.7.3", - "hardhat": "^2.10.1" + "hardhat": "^2.10.2" } } diff --git a/scripts/deploy.js b/scripts/deploy.js new file mode 100644 index 0000000..103993b --- /dev/null +++ b/scripts/deploy.js @@ -0,0 +1,28 @@ +async function main() { + const [deployer] = await ethers.getSigners(); + + console.log("Deploying contracts with the account:", deployer.address); + console.log("Account balance:", (await deployer.getBalance()).toString()); + + // deploy the token + const MyToken = await ethers.getContractFactory("MyToken"); + const myToken = await MyToken.deploy(10); + console.log("MyToken Contract address:", myToken.address); + + // then the NFT + const MyNFT = await ethers.getContractFactory("MyNFT"); + const myNFT = await MyNFT.deploy(); + console.log("MyNFT Contract address:", myNFT.address); + + // now deploy the controller + const Controller = await ethers.getContractFactory("Controller"); + const controller = await Controller.deploy(myNFT.address, myToken.address); + console.log("Controller Contract address:", controller.address); +} + +main() +.then(() => process.exit(0)) +.catch((error) => { + console.error(error); + process.exit(1); +});