From 349b334ef92b9be20c66fdd8270bf077b48a0bdf Mon Sep 17 00:00:00 2001 From: Austin Adams <36116882+aadams@users.noreply.github.com> Date: Thu, 7 Dec 2023 17:22:14 -0500 Subject: [PATCH] added delta overflow checks (#433) --- test/Pool.t.sol | 16 ++++ test/utils/LiquidityAmounts.sol | 134 ++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 test/utils/LiquidityAmounts.sol diff --git a/test/Pool.t.sol b/test/Pool.t.sol index 76179b089..a4814c343 100644 --- a/test/Pool.t.sol +++ b/test/Pool.t.sol @@ -8,6 +8,9 @@ import {PoolManager} from "../src/PoolManager.sol"; import {Position} from "../src/libraries/Position.sol"; import {TickMath} from "../src/libraries/TickMath.sol"; import {TickBitmap} from "../src/libraries/TickBitmap.sol"; +import {LiquidityAmounts} from "./utils/LiquidityAmounts.sol"; +import {Constants} from "./utils/Constants.sol"; +import {SafeCast} from "../src/libraries/SafeCast.sol"; contract PoolTest is Test { using Pool for Pool.State; @@ -67,6 +70,19 @@ contract PoolTest is Test { vm.expectRevert( abi.encodeWithSelector(TickBitmap.TickMisaligned.selector, params.tickUpper, params.tickSpacing) ); + } else { + // We need the assumptions above to calculate this + uint256 maxInt128InTypeU256 = uint256(uint128(Constants.MAX_UINT128)); + (uint256 amount0, uint256 amount1) = LiquidityAmounts.getAmountsForLiquidity( + sqrtPriceX96, + TickMath.getSqrtRatioAtTick(params.tickLower), + TickMath.getSqrtRatioAtTick(params.tickUpper), + uint128(params.liquidityDelta) + ); + + if ((amount0 > maxInt128InTypeU256) || (amount1 > maxInt128InTypeU256)) { + vm.expectRevert(abi.encodeWithSelector(SafeCast.SafeCastOverflow.selector)); + } } params.owner = address(this); diff --git a/test/utils/LiquidityAmounts.sol b/test/utils/LiquidityAmounts.sol new file mode 100644 index 000000000..6767e692c --- /dev/null +++ b/test/utils/LiquidityAmounts.sol @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: UNLICENSED +pragma solidity ^0.8.20; + +import "../../src/libraries/FullMath.sol"; +import "../../src/libraries/FixedPoint96.sol"; + +/// @title Liquidity amount functions +/// @notice Provides functions for computing liquidity amounts from token amounts and prices +library LiquidityAmounts { + /// @notice Downcasts uint256 to uint128 + /// @param x The uint258 to be downcasted + /// @return y The passed value, downcasted to uint128 + function toUint128(uint256 x) private pure returns (uint128 y) { + require((y = uint128(x)) == x); + } + + /// @notice Computes the amount of liquidity received for a given amount of token0 and price range + /// @dev Calculates amount0 * (sqrt(upper) * sqrt(lower)) / (sqrt(upper) - sqrt(lower)) + /// @param sqrtRatioAX96 A sqrt price representing the first tick boundary + /// @param sqrtRatioBX96 A sqrt price representing the second tick boundary + /// @param amount0 The amount0 being sent in + /// @return liquidity The amount of returned liquidity + function getLiquidityForAmount0(uint160 sqrtRatioAX96, uint160 sqrtRatioBX96, uint256 amount0) + internal + pure + returns (uint128 liquidity) + { + if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96); + uint256 intermediate = FullMath.mulDiv(sqrtRatioAX96, sqrtRatioBX96, FixedPoint96.Q96); + return toUint128(FullMath.mulDiv(amount0, intermediate, sqrtRatioBX96 - sqrtRatioAX96)); + } + + /// @notice Computes the amount of liquidity received for a given amount of token1 and price range + /// @dev Calculates amount1 / (sqrt(upper) - sqrt(lower)). + /// @param sqrtRatioAX96 A sqrt price representing the first tick boundary + /// @param sqrtRatioBX96 A sqrt price representing the second tick boundary + /// @param amount1 The amount1 being sent in + /// @return liquidity The amount of returned liquidity + function getLiquidityForAmount1(uint160 sqrtRatioAX96, uint160 sqrtRatioBX96, uint256 amount1) + internal + pure + returns (uint128 liquidity) + { + if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96); + return toUint128(FullMath.mulDiv(amount1, FixedPoint96.Q96, sqrtRatioBX96 - sqrtRatioAX96)); + } + + /// @notice Computes the maximum amount of liquidity received for a given amount of token0, token1, the current + /// pool prices and the prices at the tick boundaries + /// @param sqrtRatioX96 A sqrt price representing the current pool prices + /// @param sqrtRatioAX96 A sqrt price representing the first tick boundary + /// @param sqrtRatioBX96 A sqrt price representing the second tick boundary + /// @param amount0 The amount of token0 being sent in + /// @param amount1 The amount of token1 being sent in + /// @return liquidity The maximum amount of liquidity received + function getLiquidityForAmounts( + uint160 sqrtRatioX96, + uint160 sqrtRatioAX96, + uint160 sqrtRatioBX96, + uint256 amount0, + uint256 amount1 + ) internal pure returns (uint128 liquidity) { + if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96); + + if (sqrtRatioX96 <= sqrtRatioAX96) { + liquidity = getLiquidityForAmount0(sqrtRatioAX96, sqrtRatioBX96, amount0); + } else if (sqrtRatioX96 < sqrtRatioBX96) { + uint128 liquidity0 = getLiquidityForAmount0(sqrtRatioX96, sqrtRatioBX96, amount0); + uint128 liquidity1 = getLiquidityForAmount1(sqrtRatioAX96, sqrtRatioX96, amount1); + + liquidity = liquidity0 < liquidity1 ? liquidity0 : liquidity1; + } else { + liquidity = getLiquidityForAmount1(sqrtRatioAX96, sqrtRatioBX96, amount1); + } + } + + /// @notice Computes the amount of token0 for a given amount of liquidity and a price range + /// @param sqrtRatioAX96 A sqrt price representing the first tick boundary + /// @param sqrtRatioBX96 A sqrt price representing the second tick boundary + /// @param liquidity The liquidity being valued + /// @return amount0 The amount of token0 + function getAmount0ForLiquidity(uint160 sqrtRatioAX96, uint160 sqrtRatioBX96, uint128 liquidity) + internal + pure + returns (uint256 amount0) + { + if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96); + + return FullMath.mulDiv( + uint256(liquidity) << FixedPoint96.RESOLUTION, sqrtRatioBX96 - sqrtRatioAX96, sqrtRatioBX96 + ) / sqrtRatioAX96; + } + + /// @notice Computes the amount of token1 for a given amount of liquidity and a price range + /// @param sqrtRatioAX96 A sqrt price representing the first tick boundary + /// @param sqrtRatioBX96 A sqrt price representing the second tick boundary + /// @param liquidity The liquidity being valued + /// @return amount1 The amount of token1 + function getAmount1ForLiquidity(uint160 sqrtRatioAX96, uint160 sqrtRatioBX96, uint128 liquidity) + internal + pure + returns (uint256 amount1) + { + if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96); + + return FullMath.mulDiv(liquidity, sqrtRatioBX96 - sqrtRatioAX96, FixedPoint96.Q96); + } + + /// @notice Computes the token0 and token1 value for a given amount of liquidity, the current + /// pool prices and the prices at the tick boundaries + /// @param sqrtRatioX96 A sqrt price representing the current pool prices + /// @param sqrtRatioAX96 A sqrt price representing the first tick boundary + /// @param sqrtRatioBX96 A sqrt price representing the second tick boundary + /// @param liquidity The liquidity being valued + /// @return amount0 The amount of token0 + /// @return amount1 The amount of token1 + function getAmountsForLiquidity( + uint160 sqrtRatioX96, + uint160 sqrtRatioAX96, + uint160 sqrtRatioBX96, + uint128 liquidity + ) internal pure returns (uint256 amount0, uint256 amount1) { + if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96); + + if (sqrtRatioX96 <= sqrtRatioAX96) { + amount0 = getAmount0ForLiquidity(sqrtRatioAX96, sqrtRatioBX96, liquidity); + } else if (sqrtRatioX96 < sqrtRatioBX96) { + amount0 = getAmount0ForLiquidity(sqrtRatioX96, sqrtRatioBX96, liquidity); + amount1 = getAmount1ForLiquidity(sqrtRatioAX96, sqrtRatioX96, liquidity); + } else { + amount1 = getAmount1ForLiquidity(sqrtRatioAX96, sqrtRatioBX96, liquidity); + } + } +}