Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MinObservation to make observations per observer unique [CCIP-5050] #495

Merged
merged 11 commits into from
Feb 4, 2025
Prev Previous commit
Next Next commit
Make new OracleMinObservation for exec plugin consensus
  • Loading branch information
asoliman92 committed Feb 4, 2025
commit fd06a9c1d080e83d5667c513c7377d58f14c44a3
24 changes: 12 additions & 12 deletions execute/plugin_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,9 @@ func mergeMessageObservations(
aos []plugincommon.AttributedObservation[exectypes.Observation], fChain map[cciptypes.ChainSelector]int,
) exectypes.MessageObservations {
// Create a validator for each chain
validators := make(map[cciptypes.ChainSelector]consensus.MinObservation[cciptypes.Message])
validators := make(map[cciptypes.ChainSelector]consensus.OracleMinObservation[cciptypes.Message])
for selector, f := range fChain {
validators[selector] = consensus.NewMinObservation[cciptypes.Message](consensus.FPlus1(f), nil)
validators[selector] = consensus.NewOracleMinObservation[cciptypes.Message](consensus.FPlus1(f), nil)
}

// Add messages to the validator for each chain selector.
Expand Down Expand Up @@ -343,10 +343,10 @@ func mergeCommitObservations(
aos []plugincommon.AttributedObservation[exectypes.Observation], fChain map[cciptypes.ChainSelector]int,
) exectypes.CommitObservations {
// Create a validator for each chain
validators := make(map[cciptypes.ChainSelector]consensus.MinObservation[exectypes.CommitData])
validators := make(map[cciptypes.ChainSelector]consensus.OracleMinObservation[exectypes.CommitData])
for selector, f := range fChain {
validators[selector] =
consensus.NewMinObservation[exectypes.CommitData](consensus.FPlus1(f), nil)
consensus.NewOracleMinObservation[exectypes.CommitData](consensus.FPlus1(f), nil)
}

// Add reports to the validator for each chain selector.
Expand Down Expand Up @@ -384,7 +384,7 @@ func mergeMessageHashes(
fChain map[cciptypes.ChainSelector]int,
) exectypes.MessageHashes {
// Single message can transfer multiple tokens, so we need to find consensus on the token level.
validators := make(map[cciptypes.ChainSelector]map[cciptypes.SeqNum]consensus.MinObservation[cciptypes.Bytes32])
validators := make(map[cciptypes.ChainSelector]map[cciptypes.SeqNum]consensus.OracleMinObservation[cciptypes.Bytes32])
results := make(exectypes.MessageHashes)

for _, ao := range aos {
Expand All @@ -400,13 +400,13 @@ func mergeMessageHashes(
}

if _, ok1 := validators[selector]; !ok1 {
validators[selector] = make(map[cciptypes.SeqNum]consensus.MinObservation[cciptypes.Bytes32])
validators[selector] = make(map[cciptypes.SeqNum]consensus.OracleMinObservation[cciptypes.Bytes32])
}

for seqNr, hash := range seqMap {
if _, ok := validators[selector][seqNr]; !ok {
validators[selector][seqNr] =
consensus.NewMinObservation[cciptypes.Bytes32](consensus.FPlus1(f), nil)
consensus.NewOracleMinObservation[cciptypes.Bytes32](consensus.FPlus1(f), nil)
}
validators[selector][seqNr].Add(hash, ao.OracleID)
}
Expand Down Expand Up @@ -435,7 +435,7 @@ func mergeTokenObservations(
fChain map[cciptypes.ChainSelector]int,
) exectypes.TokenDataObservations {
// Single message can transfer multiple tokens, so we need to find consensus on the token level.
validators := make(map[cciptypes.ChainSelector]map[reader.MessageTokenID]consensus.MinObservation[exectypes.TokenData])
validators := make(map[cciptypes.ChainSelector]map[reader.MessageTokenID]consensus.OracleMinObservation[exectypes.TokenData])
results := make(exectypes.TokenDataObservations)

for _, ao := range aos {
Expand All @@ -451,7 +451,7 @@ func mergeTokenObservations(
}

if _, ok1 := validators[selector]; !ok1 {
validators[selector] = make(map[reader.MessageTokenID]consensus.MinObservation[exectypes.TokenData])
validators[selector] = make(map[reader.MessageTokenID]consensus.OracleMinObservation[exectypes.TokenData])
}

initResultsAndValidators(selector, f, seqMap, results, validators, ao.OracleID)
Expand Down Expand Up @@ -483,7 +483,7 @@ func initResultsAndValidators(
f int,
seqMap map[cciptypes.SeqNum]exectypes.MessageTokenData,
results exectypes.TokenDataObservations,
validators map[cciptypes.ChainSelector]map[reader.MessageTokenID]consensus.MinObservation[exectypes.TokenData],
validators map[cciptypes.ChainSelector]map[reader.MessageTokenID]consensus.OracleMinObservation[exectypes.TokenData],
oracleID commontypes.OracleID,
) {
for seqNr, msgTokenData := range seqMap {
Expand All @@ -495,7 +495,7 @@ func initResultsAndValidators(
messageTokenID := reader.NewMessageTokenID(seqNr, tokenIndex)
if _, ok := validators[selector][messageTokenID]; !ok {
validators[selector][messageTokenID] =
consensus.NewMinObservation[exectypes.TokenData](consensus.FPlus1(f), exectypes.TokenDataHash)
consensus.NewOracleMinObservation[exectypes.TokenData](consensus.FPlus1(f), exectypes.TokenDataHash)
}
validators[selector][messageTokenID].Add(tokenData, oracleID)
}
Expand All @@ -516,7 +516,7 @@ func mergeNonceObservations(
}

// Create one validator because nonces are only observed from the destination chain.
validator := consensus.NewMinObservation[NonceTriplet](consensus.FPlus1(fChainDest), nil)
validator := consensus.NewOracleMinObservation[NonceTriplet](consensus.FPlus1(fChainDest), nil)

// Add reports to the validator for each chain selector.
for _, ao := range daos {
Expand Down
20 changes: 7 additions & 13 deletions internal/plugincommon/consensus/min_observation.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package consensus

import (
"fmt"
"github.com/smartcontractkit/libocr/commontypes"

"golang.org/x/crypto/sha3"

Expand All @@ -14,16 +13,11 @@ type counter[T any] struct {
count uint
}

type observersCounter[T any] struct {
data T
observers map[commontypes.OracleID]struct{}
}

// MinObservation provides a way to ensure a minimum number of observations for
// some piece of data have occurred. It maintains an internal cache and provides a list
// of valid or invalid data points.
type MinObservation[T any] interface {
Add(data T, oracleID commontypes.OracleID)
Add(data T)
GetValid() []T
}

Expand All @@ -32,7 +26,7 @@ type MinObservation[T any] interface {
// with one another, and whether they meet the required count threshold.
type minObservation[T any] struct {
minObservation Threshold
cache map[cciptypes.Bytes32]*observersCounter[T]
cache map[cciptypes.Bytes32]*counter[T]
idFunc func(T) [32]byte
}

Expand All @@ -46,24 +40,24 @@ func NewMinObservation[T any](minThreshold Threshold, idFunc func(T) [32]byte) M
}
return &minObservation[T]{
minObservation: minThreshold,
cache: make(map[cciptypes.Bytes32]*observersCounter[T]),
cache: make(map[cciptypes.Bytes32]*counter[T]),
idFunc: idFunc,
}
}

func (cv *minObservation[T]) Add(data T, oracleID commontypes.OracleID) {
func (cv *minObservation[T]) Add(data T) {
id := cv.idFunc(data)
if _, ok := cv.cache[id]; ok {
cv.cache[id].observers[oracleID] = struct{}{}
cv.cache[id].count++
} else {
cv.cache[id] = &observersCounter[T]{data: data, observers: make(map[commontypes.OracleID]struct{})}
cv.cache[id] = &counter[T]{data: data, count: 1}
}
}

func (cv *minObservation[T]) GetValid() []T {
var validated []T
for _, rc := range cv.cache {
if len(rc.observers) >= int(cv.minObservation) {
if rc.count >= uint(cv.minObservation) {
rc := rc
validated = append(validated, rc.data)
}
Expand Down
65 changes: 65 additions & 0 deletions internal/plugincommon/consensus/oracle_min_observation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package consensus

import (
"fmt"
cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3"
"github.com/smartcontractkit/libocr/commontypes"
"golang.org/x/crypto/sha3"
)

type observersCounter[T any] struct {
data T
observers map[commontypes.OracleID]struct{}
}

// OracleMinObservation provides a way to ensure a minimum number of observations for
// some piece of data have occurred. It maintains an internal cache and provides a list
// of valid or invalid data points.
type OracleMinObservation[T any] interface {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This API doc comment should outline why this would be used instead of MinObservation. I think you can say that OracleMinObservation ensures that duplicate observations from the same observer are only counted once? TBH this should be the default behavior, I'm not even sure we should have a separate MinObservation and OracleMinObservation.

cc @winder

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in DM I agree we should have one. It's a big change that will take some time and other plugins already use maps for most observations which by design prevents duplicates.

Add(data T, oracleID commontypes.OracleID)
GetValid() []T
asoliman92 marked this conversation as resolved.
Show resolved Hide resolved
}

// minObservation is a helper object to filter data based on observation counts.
// It keeps track of all inputs, determines if they are consistent
// with one another, and whether they meet the required count threshold.
asoliman92 marked this conversation as resolved.
Show resolved Hide resolved
type oracleMinObservation[T any] struct {
minObservation Threshold
cache map[cciptypes.Bytes32]*observersCounter[T]
idFunc func(T) [32]byte
}

// NewOracleMinObservation constructs a concrete MinObservation object. The
// supplied idFunc is used to generate a uniqueID for the type being observed.
asoliman92 marked this conversation as resolved.
Show resolved Hide resolved
func NewOracleMinObservation[T any](minThreshold Threshold, idFunc func(T) [32]byte) OracleMinObservation[T] {
if idFunc == nil {
idFunc = func(data T) [32]byte {
return sha3.Sum256([]byte(fmt.Sprintf("%v", data)))
}
}
Comment on lines +39 to +43
Copy link
Contributor

@winder winder Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud, but could this version be implemented by overloading the ID func?

	if idFunc == nil {
		idFunc = func(oracle int, data T) [32]byte {
			return sha3.Sum256([]byte(fmt.Sprintf("%d_%v", oracle, data)))
		}
	} else {
		idFunc = func(oracle int, data T) [32]byte {
			return sha3.Sum256([]byte(fmt.Sprintf("%d_%v", oracle, idFunc(data))))
		}
	}

Copy link
Contributor Author

@asoliman92 asoliman92 Feb 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thought about this and tried it, was't straightforward to use from the callers' though. don't remember the specific reason for it now.

return &oracleMinObservation[T]{
minObservation: minThreshold,
cache: make(map[cciptypes.Bytes32]*observersCounter[T]),
idFunc: idFunc,
}
}

func (cv *oracleMinObservation[T]) Add(data T, oracleID commontypes.OracleID) {
id := cv.idFunc(data)
if _, ok := cv.cache[id]; ok {
cv.cache[id].observers[oracleID] = struct{}{}
} else {
cv.cache[id] = &observersCounter[T]{data: data, observers: make(map[commontypes.OracleID]struct{})}
}
}

func (cv *oracleMinObservation[T]) GetValid() []T {
var validated []T
for _, rc := range cv.cache {
if len(rc.observers) >= int(cv.minObservation) {
rc := rc
validated = append(validated, rc.data)
}
}
return validated
}
137 changes: 137 additions & 0 deletions internal/plugincommon/consensus/oracle_min_observation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package consensus_test

import (
"fmt"
"github.com/smartcontractkit/libocr/commontypes"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/crypto/sha3"

"github.com/smartcontractkit/chainlink-ccip/execute/exectypes"
"github.com/smartcontractkit/chainlink-ccip/internal/plugincommon/consensus"
cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3"
)

func Test_CommitReportValidator_Oracle_ExecutePluginCommitData(t *testing.T) {
tests := []struct {
name string
min consensus.Threshold
reports []exectypes.CommitData
valid []exectypes.CommitData
}{
{
name: "empty",
valid: nil,
},
{
name: "single report, enough observations",
min: 1,
reports: []exectypes.CommitData{
{MerkleRoot: [32]byte{1}},
},
valid: []exectypes.CommitData{
{MerkleRoot: [32]byte{1}},
},
},
{
name: "single report, not enough observations",
min: 2,
reports: []exectypes.CommitData{
{MerkleRoot: [32]byte{1}},
},
valid: nil,
},
{
name: "multiple reports, partial observations",
min: 2,
reports: []exectypes.CommitData{
{MerkleRoot: [32]byte{3}},
{MerkleRoot: [32]byte{1}},
{MerkleRoot: [32]byte{2}},
{MerkleRoot: [32]byte{1}},
{MerkleRoot: [32]byte{2}},
},
valid: []exectypes.CommitData{
{MerkleRoot: [32]byte{1}},
{MerkleRoot: [32]byte{2}},
},
},
{
name: "multiple reports for same root",
min: 2,
reports: []exectypes.CommitData{
{MerkleRoot: [32]byte{1}, BlockNum: 1},
{MerkleRoot: [32]byte{1}, BlockNum: 2},
{MerkleRoot: [32]byte{1}, BlockNum: 3},
{MerkleRoot: [32]byte{1}, BlockNum: 4},
{MerkleRoot: [32]byte{1}, BlockNum: 1},
},
valid: []exectypes.CommitData{
{MerkleRoot: [32]byte{1}, BlockNum: 1},
},
},
{
name: "different executed messages same root",
min: 2,
reports: []exectypes.CommitData{
{MerkleRoot: [32]byte{1}, ExecutedMessages: []cciptypes.SeqNum{1, 2}},
{MerkleRoot: [32]byte{1}, ExecutedMessages: []cciptypes.SeqNum{2, 3}},
{MerkleRoot: [32]byte{1}, ExecutedMessages: []cciptypes.SeqNum{3, 4}},
{MerkleRoot: [32]byte{1}, ExecutedMessages: []cciptypes.SeqNum{4, 5}},
{MerkleRoot: [32]byte{1}, ExecutedMessages: []cciptypes.SeqNum{5, 6}},
{MerkleRoot: [32]byte{1}, ExecutedMessages: []cciptypes.SeqNum{1, 2}},
},
valid: []exectypes.CommitData{
{MerkleRoot: [32]byte{1}, ExecutedMessages: []cciptypes.SeqNum{1, 2}},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

// Initialize the minObservation
idFunc := func(data exectypes.CommitData) [32]byte {
return sha3.Sum256([]byte(fmt.Sprintf("%v", data)))
}
validator := consensus.NewOracleMinObservation[exectypes.CommitData](tt.min, idFunc)
for i, report := range tt.reports {
validator.Add(report, commontypes.OracleID(i))
}

// Test the results
got := validator.GetValid()
if !assert.ElementsMatch(t, got, tt.valid) {
t.Errorf("GetValid() = %v, valid %v", got, tt.valid)
}
})
}
}

func Test_CommitReportValidator_Oracle_Generics(t *testing.T) {
type Generic struct {
number int
}

// Initialize the minObservation
idFunc := func(data Generic) [32]byte {
return sha3.Sum256([]byte(fmt.Sprintf("%v", data)))
}
validator := consensus.NewOracleMinObservation[Generic](2, idFunc)

wantValue := Generic{number: 1}
otherValue := Generic{number: 2}

validator.Add(wantValue, 1)
validator.Add(wantValue, 2)
validator.Add(otherValue, 3)

// Test the results

wantValid := []Generic{wantValue}
got := validator.GetValid()
if !assert.ElementsMatch(t, got, wantValid) {
t.Errorf("GetValid() = %v, valid %v", got, wantValid)
}
}
Loading