Skip to content

Commit 368ad1a

Browse files
committed
acp-118
Signed-off-by: Joshua Kim <20001595+joshua-kim@users.noreply.github.com>
1 parent e7648e5 commit 368ad1a

File tree

5 files changed

+665
-0
lines changed

5 files changed

+665
-0
lines changed

network/acp118/aggregator.go

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
2+
// See the file LICENSE for licensing terms.
3+
4+
package acp118
5+
6+
import (
7+
"context"
8+
"errors"
9+
"fmt"
10+
"sync"
11+
12+
"go.uber.org/zap"
13+
"golang.org/x/sync/semaphore"
14+
"google.golang.org/protobuf/proto"
15+
16+
"github.com/ava-labs/avalanchego/ids"
17+
"github.com/ava-labs/avalanchego/network/p2p"
18+
"github.com/ava-labs/avalanchego/proto/pb/sdk"
19+
"github.com/ava-labs/avalanchego/utils/crypto/bls"
20+
"github.com/ava-labs/avalanchego/utils/logging"
21+
"github.com/ava-labs/avalanchego/utils/set"
22+
"github.com/ava-labs/avalanchego/vms/platformvm/warp"
23+
)
24+
25+
var (
26+
ErrDuplicateValidator = errors.New("duplicate validator")
27+
ErrInsufficientSignatures = errors.New("failed to aggregate sufficient stake weight of signatures")
28+
)
29+
30+
type result struct {
31+
message *warp.Message
32+
err error
33+
}
34+
35+
type Validator struct {
36+
NodeID ids.NodeID
37+
PublicKey *bls.PublicKey
38+
Weight uint64
39+
}
40+
41+
type indexedValidator struct {
42+
Validator
43+
I int
44+
}
45+
46+
// NewSignatureAggregator returns an instance of SignatureAggregator
47+
func NewSignatureAggregator(
48+
log logging.Logger,
49+
client *p2p.Client,
50+
maxPending int,
51+
) *SignatureAggregator {
52+
return &SignatureAggregator{
53+
log: log,
54+
client: client,
55+
maxPending: int64(maxPending),
56+
}
57+
}
58+
59+
// SignatureAggregator aggregates validator signatures for warp messages
60+
type SignatureAggregator struct {
61+
log logging.Logger
62+
client *p2p.Client
63+
maxPending int64
64+
}
65+
66+
// AggregateSignatures blocks until stakeWeightThreshold of validators signs the
67+
// provided message. Validators are issued requests in the caller-specified
68+
// order.
69+
func (s *SignatureAggregator) AggregateSignatures(
70+
parentCtx context.Context,
71+
message *warp.UnsignedMessage,
72+
justification []byte,
73+
validators []Validator,
74+
stakeWeightThreshold uint64,
75+
) (*warp.Message, error) {
76+
ctx, cancel := context.WithCancel(parentCtx)
77+
defer cancel()
78+
79+
request := &sdk.SignatureRequest{
80+
Message: message.Bytes(),
81+
Justification: justification,
82+
}
83+
84+
requestBytes, err := proto.Marshal(request)
85+
if err != nil {
86+
return nil, fmt.Errorf("failed to marshal signature request: %w", err)
87+
}
88+
89+
done := make(chan result)
90+
pendingRequests := semaphore.NewWeighted(s.maxPending)
91+
lock := &sync.Mutex{}
92+
aggregatedStakeWeight := uint64(0)
93+
attemptedStakeWeight := uint64(0)
94+
totalStakeWeight := uint64(0)
95+
signatures := make([]*bls.Signature, 0)
96+
signerBitSet := set.NewBits()
97+
98+
nodeIDsToValidator := make(map[ids.NodeID]indexedValidator)
99+
for i, v := range validators {
100+
totalStakeWeight += v.Weight
101+
102+
// Sanity check the validator set provided by the caller
103+
if _, ok := nodeIDsToValidator[v.NodeID]; ok {
104+
return nil, fmt.Errorf("%w: %s", ErrDuplicateValidator, v.NodeID)
105+
}
106+
107+
nodeIDsToValidator[v.NodeID] = indexedValidator{
108+
I: i,
109+
Validator: v,
110+
}
111+
}
112+
113+
onResponse := func(
114+
_ context.Context,
115+
nodeID ids.NodeID,
116+
responseBytes []byte,
117+
err error,
118+
) {
119+
// We are guaranteed a response from a node in the validator set
120+
validator := nodeIDsToValidator[nodeID]
121+
122+
defer func() {
123+
lock.Lock()
124+
attemptedStakeWeight += validator.Weight
125+
remainingStakeWeight := totalStakeWeight - attemptedStakeWeight
126+
failed := remainingStakeWeight < stakeWeightThreshold
127+
lock.Unlock()
128+
129+
if failed {
130+
done <- result{err: ErrInsufficientSignatures}
131+
}
132+
133+
pendingRequests.Release(1)
134+
}()
135+
136+
if err != nil {
137+
s.log.Debug(
138+
"dropping response",
139+
zap.Stringer("nodeID", nodeID),
140+
zap.Error(err),
141+
)
142+
return
143+
}
144+
145+
response := &sdk.SignatureResponse{}
146+
if err := proto.Unmarshal(responseBytes, response); err != nil {
147+
s.log.Debug(
148+
"dropping response",
149+
zap.Stringer("nodeID", nodeID),
150+
zap.Error(err),
151+
)
152+
return
153+
}
154+
155+
signature, err := bls.SignatureFromBytes(response.Signature)
156+
if err != nil {
157+
s.log.Debug(
158+
"dropping response",
159+
zap.Stringer("nodeID", nodeID),
160+
zap.String("reason", "invalid signature"),
161+
zap.Error(err),
162+
)
163+
return
164+
}
165+
166+
if !bls.Verify(validator.PublicKey, signature, message.Bytes()) {
167+
s.log.Debug(
168+
"dropping response",
169+
zap.Stringer("nodeID", nodeID),
170+
zap.String("reason", "public key failed verification"),
171+
)
172+
return
173+
}
174+
175+
lock.Lock()
176+
signerBitSet.Add(validator.I)
177+
signatures = append(signatures, signature)
178+
aggregatedStakeWeight += validator.Weight
179+
180+
if aggregatedStakeWeight >= stakeWeightThreshold {
181+
aggregateSignature, err := bls.AggregateSignatures(signatures)
182+
if err != nil {
183+
done <- result{err: err}
184+
lock.Unlock()
185+
return
186+
}
187+
188+
bitSetSignature := &warp.BitSetSignature{
189+
Signers: signerBitSet.Bytes(),
190+
Signature: [bls.SignatureLen]byte{},
191+
}
192+
193+
copy(bitSetSignature.Signature[:], bls.SignatureToBytes(aggregateSignature))
194+
signedMessage, err := warp.NewMessage(message, bitSetSignature)
195+
done <- result{message: signedMessage, err: err}
196+
lock.Unlock()
197+
return
198+
}
199+
200+
lock.Unlock()
201+
}
202+
203+
for _, validator := range validators {
204+
if err := pendingRequests.Acquire(ctx, 1); err != nil {
205+
return nil, err
206+
}
207+
208+
// Avoid loop shadowing in goroutine
209+
validatorCopy := validator
210+
go func() {
211+
if err := s.client.AppRequest(
212+
ctx,
213+
set.Of(validatorCopy.NodeID),
214+
requestBytes,
215+
onResponse,
216+
); err != nil {
217+
done <- result{err: err}
218+
return
219+
}
220+
}()
221+
}
222+
223+
select {
224+
case <-ctx.Done():
225+
return nil, ctx.Err()
226+
case r := <-done:
227+
return r.message, r.err
228+
}
229+
}

0 commit comments

Comments
 (0)