Skip to content

Commit

Permalink
✨ MinHeapLib k-smallest view function (#797)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vectorized authored Jan 16, 2024
1 parent 03fe77a commit 5ddc201
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 18 deletions.
16 changes: 10 additions & 6 deletions .gas-snapshot
Original file line number Diff line number Diff line change
Expand Up @@ -844,13 +844,17 @@ MetadataReaderLibTest:testReadStringTruncated(uint256) (runs: 256, μ: 775902, ~
MetadataReaderLibTest:testReadUint() (gas: 542217)
MetadataReaderLibTest:testReadUint(uint256) (runs: 256, μ: 22765, ~: 23854)
MetadataReaderLibTest:test__codesize() (gas: 8645)
MinHeapLibTest:testHeapEnqueue(uint256) (runs: 256, μ: 182135, ~: 178082)
MinHeapLibTest:testHeapEnqueueGas(uint256) (runs: 256, μ: 293266, ~: 293389)
MinHeapLibTest:testHeapPushAndPop(uint256) (runs: 256, μ: 110123, ~: 99702)
MinHeapLibTest:testHeapPushPop(uint256) (runs: 256, μ: 249211, ~: 257370)
MinHeapLibTest:testHeapReplace(uint256) (runs: 256, μ: 310408, ~: 320169)
MinHeapLibTest:testHeapEnqueue(uint256) (runs: 256, μ: 186631, ~: 182615)
MinHeapLibTest:testHeapEnqueue2(uint256) (runs: 256, μ: 609201, ~: 426081)
MinHeapLibTest:testHeapEnqueueGas() (gas: 856319)
MinHeapLibTest:testHeapPSiftTrick(uint256,uint256,uint256) (runs: 256, μ: 717, ~: 877)
MinHeapLibTest:testHeapPushAndPop(uint256) (runs: 256, μ: 103685, ~: 99755)
MinHeapLibTest:testHeapPushPop(uint256) (runs: 256, μ: 237955, ~: 233915)
MinHeapLibTest:testHeapReplace(uint256) (runs: 256, μ: 296075, ~: 291888)
MinHeapLibTest:testHeapRoot(uint256) (runs: 256, μ: 5232, ~: 5232)
MinHeapLibTest:test__codesize() (gas: 5404)
MinHeapLibTest:testHeapSmallest(uint256) (runs: 256, μ: 1648506, ~: 1287956)
MinHeapLibTest:testHeapSmallestGas() (gas: 49973111)
MinHeapLibTest:test__codesize() (gas: 8139)
MulticallableTest:testMulticallableBenchmark() (gas: 29588)
MulticallableTest:testMulticallableOriginalBenchmark() (gas: 38849)
MulticallableTest:testMulticallablePreservesMsgSender() (gas: 11193)
Expand Down
72 changes: 68 additions & 4 deletions src/utils/MinHeapLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ library MinHeapLib {
/*.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•*/

// Tips:
// - To use as a max-map, negate the values.
// - To use as a max-heap, bitwise negate the values (e.g. `heap.push(~x)`).
// - If use on tuples, pack the tuple values into a single integer.
// - To use on signed integers, convert the signed integers into
// their ordered unsigned counterparts via `uint256(x) + (1 << 255)`.
Expand All @@ -36,14 +36,78 @@ library MinHeapLib {
/// @solidity memory-safe-assembly
assembly {
if iszero(sload(heap.slot)) {
mstore(0x00, 0xa6ca772e) // Store the function selector of `HeapIsEmpty()`.
revert(0x1c, 0x04) // Revert with (offset, size).
mstore(0x00, 0xa6ca772e) // `HeapIsEmpty()`.
revert(0x1c, 0x04)
}
mstore(0x00, heap.slot)
result := sload(keccak256(0x00, 0x20))
}
}

/// @dev Returns an array of the `k` smallest items in the heap,
/// sorted in ascending order, without modifying the heap.
/// If the heap has less than `k` items, all items in the heap will be returned.
function smallest(Heap storage heap, uint256 k) internal view returns (uint256[] memory a) {
/// @solidity memory-safe-assembly
assembly {
function pIndex(h_, p_) -> _i {
_i := mload(add(0x20, add(h_, shl(6, p_))))
}
function pValue(h_, p_) -> _v {
_v := mload(add(h_, shl(6, p_)))
}
function pSet(h_, p_, i_, v_) {
mstore(add(h_, shl(6, p_)), v_)
mstore(add(0x20, add(h_, shl(6, p_))), i_)
}
function pSiftdown(h_, p_, i_, v_) {
for {} 1 {} {
let u_ := shr(1, sub(p_, 1))
if iszero(mul(p_, lt(v_, pValue(h_, u_)))) { break }
pSet(h_, p_, pIndex(h_, u_), pValue(h_, u_))
p_ := u_
}
pSet(h_, p_, i_, v_)
}
function pSiftup(h_, e_, i_, v_) {
let p_ := 0
for { let c_ := 1 } lt(c_, e_) { c_ := add(1, shl(1, p_)) } {
c_ := add(c_, gt(pValue(h_, c_), pValue(h_, add(c_, lt(add(c_, 1), e_)))))
pSet(h_, p_, pIndex(h_, c_), pValue(h_, c_))
p_ := c_
}
pSiftdown(h_, p_, i_, v_)
}
a := mload(0x40)
mstore(0x00, heap.slot)
let sOffset := keccak256(0x00, 0x20)
let o := add(a, 0x20) // Offset into `a`.
let n := sload(heap.slot) // The number of items in the heap.
let m := xor(n, mul(xor(n, k), lt(k, n))) // `min(k, n)`.
let h := add(o, shl(5, m)) // Priority queue.
pSet(h, 0, 0, sload(sOffset)) // Store the root into the priority queue.
for { let e := iszero(eq(o, h)) } e {} {
mstore(o, pValue(h, 0))
o := add(0x20, o)
if eq(o, h) { break }
let childPos := add(shl(1, pIndex(h, 0)), 1)
if iszero(lt(childPos, n)) {
e := sub(e, 1)
pSiftup(h, e, pIndex(h, e), pValue(h, e))
continue
}
pSiftup(h, e, childPos, sload(add(sOffset, childPos)))
childPos := add(1, childPos)
if iszero(eq(childPos, n)) {
pSiftdown(h, e, childPos, sload(add(sOffset, childPos)))
e := add(e, 1)
}
}
mstore(a, shr(5, sub(o, add(a, 0x20)))) // Store the length.
mstore(0x40, o) // Allocate memory.
}
}

/// @dev Returns the number of items in the heap.
function length(Heap storage heap) internal view returns (uint256) {
return heap.data.length;
Expand Down Expand Up @@ -180,7 +244,7 @@ library MinHeapLib {
let child := sload(add(sOffset, childPos))
let rightPos := add(childPos, 1)
let right := sload(add(sOffset, rightPos))
if iszero(and(lt(rightPos, n), iszero(lt(child, right)))) {
if iszero(gt(lt(rightPos, n), lt(child, right))) {
right := child
rightPos := childPos
}
Expand Down
165 changes: 157 additions & 8 deletions test/MinHeapLib.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -41,38 +41,171 @@ contract MinHeapLibTest is SoladyTest {
function testHeapPushPop(uint256) public {
unchecked {
uint256 n = _random() % 8;
uint256[] memory a = new uint256[](n + 1);
for (uint256 i; i < n; ++i) {
uint256 r = _random();
heap0.push(r);
heap1.push(r);
a[i + 1] = r;
}
n = _random() % 8;
for (uint256 i; i < n; ++i) {
uint256 r = _random();
a[0] = r;
LibSort.insertionSort(a);
uint256 popped0 = heap0.pushPop(r);
heap1.push(r);
uint256 popped1 = heap1.pop();
assertEq(popped0, popped1);
}
LibSort.insertionSort(a);
n = heap0.length();
for (uint256 i; i < n; ++i) {
assertEq(heap0.pop(), a[i + 1]);
}
}
}

function testHeapReplace(uint256) public {
unchecked {
uint256 n = _random() % 8 + 1;
uint256[] memory a = new uint256[](n);
for (uint256 i; i < n; ++i) {
uint256 r = _random();
heap0.push(r);
heap1.push(r);
a[i] = r;
}
n = _random() % 8;
for (uint256 i; i < n; ++i) {
uint256 r = _random();
LibSort.insertionSort(a);
a[0] = r;
uint256 popped0 = heap0.replace(r);
uint256 popped1 = heap1.pop();
heap1.push(r);
assertEq(popped0, popped1);
}
LibSort.insertionSort(a);
n = heap0.length();
for (uint256 i; i < n; ++i) {
assertEq(heap0.pop(), a[i]);
}
}
}

function testHeapSmallest(uint256) public brutalizeMemory {
unchecked {
uint256 n = _random() & 15 == 0 ? _random() % 256 : _random() % 32;
for (uint256 i; i < n; ++i) {
heap0.push(_random());
}
if (_random() & 7 == 0) {
n = _random() % 32;
for (uint256 i; i < n; ++i) {
heap0.pushPop(_random());
if (_random() & 1 == 0) {
heap0.push(_random());
if (_random() & 1 == 0) heap0.pop();
}
if (_random() & 1 == 0) if (heap0.length() != 0) heap0.replace(_random());
}
}
uint256 k = _random() & 15 == 0 ? _random() % 256 : _random() % 32;
k = _random() & 31 == 0 ? 1 << 255 : k;
if (_random() & 7 == 0) _brutalizeMemory();
uint256[] memory computed = heap0.smallest(k);
_checkMemory();
if (_random() & 7 == 0) _brutalizeMemory();
assertEq(computed, _smallest(heap0.data, k));
}
}

function testHeapSmallestGas() public {
unchecked {
for (uint256 i; i < 2048; ++i) {
heap0.push(_random());
}
uint256 gasBefore = gasleft();
heap0.smallest(512);
uint256 gasUsed = gasBefore - gasleft();
emit LogUint("gasUsed", gasUsed);
}
}

function _smallest(uint256[] memory a, uint256 n)
internal
view
returns (uint256[] memory result)
{
result = _copy(a);
LibSort.insertionSort(result);
uint256 k = _min(n, result.length);
/// @solidity memory-safe-assembly
assembly {
mstore(result, k)
}
}

function _copy(uint256[] memory a) private view returns (uint256[] memory b) {
/// @solidity memory-safe-assembly
assembly {
b := mload(0x40)
let n := add(shl(5, mload(a)), 0x20)
pop(staticcall(gas(), 4, a, n, b, n))
mstore(0x40, add(b, n))
}
}

function _min(uint256 a, uint256 b) private pure returns (uint256) {
return a < b ? a : b;
}

function testHeapPSiftTrick(uint256 c, uint256 h, uint256 e) public {
assertEq(_heapPSiftTrick(c, h, e), _heapPSiftTrickOriginal(c, h, e));
}

function _heapPSiftTrick(uint256 c, uint256 h, uint256 e)
internal
pure
returns (uint256 result)
{
/// @solidity memory-safe-assembly
assembly {
function pValue(h_, p_) -> _v {
mstore(0x00, h_)
mstore(0x20, p_)
_v := keccak256(0x00, 0x40)
}
if lt(c, e) {
c := add(c, gt(pValue(h, c), pValue(h, add(c, lt(add(c, 1), e)))))
result := c
}
}
}

function _heapPSiftTrickOriginal(uint256 childPos, uint256 sOffset, uint256 n)
internal
pure
returns (uint256 result)
{
/// @solidity memory-safe-assembly
assembly {
function pValue(h_, p_) -> _v {
mstore(0x00, h_)
mstore(0x20, p_)
_v := keccak256(0x00, 0x40)
}
if lt(childPos, n) {
let child := pValue(sOffset, childPos)
let rightPos := add(childPos, 1)
let right := pValue(sOffset, rightPos)
if iszero(and(lt(rightPos, n), iszero(lt(child, right)))) {
right := child
rightPos := childPos
}
result := rightPos
}
}
}

Expand Down Expand Up @@ -113,18 +246,34 @@ contract MinHeapLibTest is SoladyTest {
}
}

function testHeapEnqueueGas(uint256) public {
function testHeapEnqueue2(uint256) public {
unchecked {
for (uint256 i; i < 16; ++i) {
this.enqueue(i, 8);
}
for (uint256 i; i < 16; ++i) {
this.enqueue(_random() % 16, 8);
uint256 maxLength = _random() & 31 == 0 ? 1 << 255 : _random() % 32 + 1;
uint256 m = _random() % 32 + 1;
for (uint256 i; i < m; ++i) {
uint256 r = _random();
heap0.enqueue(r, maxLength);
heap1.push(r);
if (heap1.length() > maxLength) heap1.pop();
}
uint256 k = _random() % m;
k = _random() & 31 == 0 ? 1 << 255 : k;
assertEq(heap0.smallest(k), heap1.smallest(k));
}
}

function enqueue(uint256 value, uint256 maxLength) public {
heap0.enqueue(value, maxLength);
function testHeapEnqueueGas() public {
unchecked {
for (uint256 t = 8; t < 16; ++t) {
uint256 maxLength = t;
for (uint256 i; i < 16; ++i) {
heap0.enqueue(i, maxLength);
}
for (uint256 i; i < 16; ++i) {
heap0.enqueue(_random() % 16, maxLength);
}
}
while (heap0.length() != 0) heap0.pop();
}
}
}

0 comments on commit 5ddc201

Please sign in to comment.