From d7c13bd1601e28ef67d37821db803b9582b0819b Mon Sep 17 00:00:00 2001 From: Louis Liu <35095310+louisliu2048@users.noreply.github.com> Date: Wed, 31 Jul 2024 19:43:25 +0800 Subject: [PATCH] add counter cache to opt process (#882) * add counter cache to opt process * fix method to calculate Combined Counters --------- Co-authored-by: Valentin Staykov <79150443+V-Staykov@users.noreply.github.com> --- core/vm/zk_batch_counters.go | 54 ++++++++++++++----- core/vm/zk_counters.go | 10 ++++ .../stage_sequence_execute_transactions.go | 1 + zk/tests/zk_counters_test.go | 2 + 4 files changed, 53 insertions(+), 14 deletions(-) diff --git a/core/vm/zk_batch_counters.go b/core/vm/zk_batch_counters.go index 5640ef2c93a..49bd91261e8 100644 --- a/core/vm/zk_batch_counters.go +++ b/core/vm/zk_batch_counters.go @@ -18,12 +18,20 @@ type BatchCounterCollector struct { forkId uint16 unlimitedCounters bool addonCounters *Counters + + rlpCombinedCounters Counters + executionCombinedCounters Counters + processingCombinedCounters Counters + + rlpCombinedCountersCache Counters + executionCombinedCountersCache Counters + processingCombinedCountersCache Counters } func NewBatchCounterCollector(smtMaxLevel int, forkId uint16, mcpReduction float64, unlimitedCounters bool, addonCounters *Counters) *BatchCounterCollector { smtLevels := calculateSmtLevels(smtMaxLevel, 0, mcpReduction) smtLevelsForTransaction := calculateSmtLevels(smtMaxLevel, 32, mcpReduction) - return &BatchCounterCollector{ + bcc := BatchCounterCollector{ transactions: []*TransactionCounter{}, smtLevels: smtLevels, smtLevelsForTransaction: smtLevelsForTransaction, @@ -32,6 +40,12 @@ func NewBatchCounterCollector(smtMaxLevel int, forkId uint16, mcpReduction float unlimitedCounters: unlimitedCounters, addonCounters: addonCounters, } + + bcc.rlpCombinedCounters = bcc.NewCounters() + bcc.executionCombinedCounters = bcc.NewCounters() + bcc.processingCombinedCounters = bcc.NewCounters() + + return &bcc } func (bcc *BatchCounterCollector) Clone() *BatchCounterCollector { @@ -55,6 +69,10 @@ func (bcc *BatchCounterCollector) Clone() *BatchCounterCollector { blockCount: bcc.blockCount, forkId: bcc.forkId, unlimitedCounters: bcc.unlimitedCounters, + + rlpCombinedCounters: bcc.rlpCombinedCounters.Clone(), + executionCombinedCounters: bcc.executionCombinedCounters.Clone(), + processingCombinedCounters: bcc.processingCombinedCounters.Clone(), } } @@ -69,6 +87,7 @@ func (bcc *BatchCounterCollector) AddNewTransactionCounters(txCounters *Transact } bcc.transactions = append(bcc.transactions, txCounters) + bcc.UpdateRlpCountersCache(txCounters) return bcc.CheckForOverflow(false) //no need to calculate the merkle proof here } @@ -202,19 +221,10 @@ func (bcc *BatchCounterCollector) CombineCollectors(verifyMerkleProof bool) (Cou } } - for _, tx := range bcc.transactions { - for k, v := range tx.rlpCounters.counters { - combined[k].used += v.used - combined[k].remaining -= v.used - } - for k, v := range tx.executionCounters.counters { - combined[k].used += v.used - combined[k].remaining -= v.used - } - for k, v := range tx.processingCounters.counters { - combined[k].used += v.used - combined[k].remaining -= v.used - } + for k, _ := range combined { + val := bcc.rlpCombinedCounters[k].used + bcc.executionCombinedCounters[k].used + bcc.processingCombinedCounters[k].used + combined[k].used += val + combined[k].remaining -= val } return combined, nil @@ -260,3 +270,19 @@ func (bcc *BatchCounterCollector) CombineCollectorsNoChanges(verifyMerkleProof b return combined } + +func (bcc *BatchCounterCollector) UpdateRlpCountersCache(txCounters *TransactionCounter) { + for k, v := range txCounters.rlpCounters.counters { + bcc.rlpCombinedCounters[k].used += v.used + } +} + +func (bcc *BatchCounterCollector) UpdateExecutionAndProcessingCountersCache(txCounters *TransactionCounter) { + for k, v := range txCounters.executionCounters.counters { + bcc.executionCombinedCounters[k].used += v.used + } + + for k, v := range txCounters.processingCounters.counters { + bcc.processingCombinedCounters[k].used += v.used + } +} diff --git a/core/vm/zk_counters.go b/core/vm/zk_counters.go index 0329a2d8883..953e0b109fd 100644 --- a/core/vm/zk_counters.go +++ b/core/vm/zk_counters.go @@ -134,6 +134,16 @@ func (c *Counters) GetPoseidonPaddings() *Counter { return (*c)[D] } +func (cc Counters) Clone() Counters { + var clonedCounters Counters = Counters{} + + for k, v := range cc { + clonedCounters[k] = v.Clone() + } + + return clonedCounters +} + type CounterKey string var ( diff --git a/zk/stages/stage_sequence_execute_transactions.go b/zk/stages/stage_sequence_execute_transactions.go index 37be7b898c3..1c8fff58c32 100644 --- a/zk/stages/stage_sequence_execute_transactions.go +++ b/zk/stages/stage_sequence_execute_transactions.go @@ -218,6 +218,7 @@ func attemptAddTransaction( return nil, nil, false, err } + batchCounters.UpdateExecutionAndProcessingCountersCache(txCounters) // now that we have executed we can check again for an overflow if overflow, err = batchCounters.CheckForOverflow(l1InfoIndex != 0); err != nil { return nil, nil, false, err diff --git a/zk/tests/zk_counters_test.go b/zk/tests/zk_counters_test.go index d47b4a7b6a2..9f494f7de82 100644 --- a/zk/tests/zk_counters_test.go +++ b/zk/tests/zk_counters_test.go @@ -337,6 +337,8 @@ func runTest(t *testing.T, test vector, err error, fileName string, idx int) { if err = txCounters.ProcessTx(ibs, result.ReturnData); err != nil { t.Fatal(err) } + + batchCollector.UpdateExecutionAndProcessingCountersCache(txCounters) } }