From 88b1ac3732a083a049db94000820a7b6da7e1306 Mon Sep 17 00:00:00 2001
From: Ian Shim <100327837+ian-shim@users.noreply.github.com>
Date: Tue, 19 Nov 2024 16:56:02 -0800
Subject: [PATCH] [v2] node/relay bug fixes (#908)

---
 api/clients/relay_client.go     | 92 +++++++++++++++++++--------------
 node/config.go                  |  3 ++
 node/flags/flags.go             |  7 +++
 node/grpc/server_test.go        |  1 +
 node/grpc/server_v2.go          | 16 ++++++
 node/grpc/server_v2_test.go     | 78 +++++++++++++++-------------
 node/node.go                    | 23 ++++++++-
 node/store_v2.go                |  7 +++
 relay/metadata_provider.go      | 17 ++++--
 relay/metadata_provider_test.go | 14 +++--
 relay/server.go                 |  8 +--
 relay/server_test.go            | 27 ++++++++--
 12 files changed, 202 insertions(+), 91 deletions(-)

diff --git a/api/clients/relay_client.go b/api/clients/relay_client.go
index 185f5b98e5..5ed6a0ead7 100644
--- a/api/clients/relay_client.go
+++ b/api/clients/relay_client.go
@@ -45,11 +45,15 @@ type RelayClient interface {
 
 type relayClient struct {
     config *RelayClientConfig
 
-    initOnce map[corev2.RelayKey]*sync.Once
-    conns    map[corev2.RelayKey]*grpc.ClientConn
-    logger   logging.Logger
-
-    grpcClients map[corev2.RelayKey]relaygrpc.RelayClient
+    // initOnce is used to ensure that the connection to each relay is initialized only once.
+    // It maps relay key to a sync.Once instance: `map[corev2.RelayKey]*sync.Once`
+    initOnce *sync.Map
+    // conns maps relay key to the gRPC connection: `map[corev2.RelayKey]*grpc.ClientConn`
+    conns sync.Map
+    logger logging.Logger
+
+    // grpcClients maps relay key to the gRPC client: `map[corev2.RelayKey]relaygrpc.RelayClient`
+    grpcClients sync.Map
 }
 
 var _ RelayClient = (*relayClient)(nil)
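The struct change above is the core of the client fix: the old plain maps were populated lazily on first use, so two concurrent requests hitting different relays for the first time could write the same map at the same time. A minimal, self-contained sketch of the hazard and the sync.Map remedy — illustrative names only, not the client's code:

package main

import "sync"

func main() {
    // Plain-map version (commented out): two goroutines storing into a
    // map[int]string concurrently can abort with "fatal error: concurrent map writes".
    // conns := make(map[int]string)

    // sync.Map version: concurrent Store/Load are safe without an explicit mutex,
    // at the cost of values coming back as interface{} (hence the type assertions
    // in getClient later in this file's diff).
    var conns sync.Map
    var wg sync.WaitGroup
    for i := 0; i < 2; i++ {
        wg.Add(1)
        go func(key int) {
            defer wg.Done()
            conns.Store(key, "connection")
        }(i)
    }
    wg.Wait()

    if v, ok := conns.Load(0); ok {
        _ = v.(string)
    }
}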
@@ -57,37 +61,28 @@ var _ RelayClient = (*relayClient)(nil)
 
 // NewRelayClient creates a new RelayClient that connects to the relays specified in the config.
 // It keeps a connection to each relay and reuses it for subsequent requests, and the connection is lazily instantiated.
 func NewRelayClient(config *RelayClientConfig, logger logging.Logger) (*relayClient, error) {
-    if config == nil || len(config.Sockets) > 0 {
+    if config == nil || len(config.Sockets) <= 0 {
         return nil, fmt.Errorf("invalid config: %v", config)
     }
 
-    initOnce := make(map[corev2.RelayKey]*sync.Once)
-    conns := make(map[corev2.RelayKey]*grpc.ClientConn)
-    grpcClients := make(map[corev2.RelayKey]relaygrpc.RelayClient)
+    initOnce := sync.Map{}
     for key := range config.Sockets {
-        initOnce[key] = &sync.Once{}
+        initOnce.Store(key, &sync.Once{})
     }
     return &relayClient{
         config: config,
-        initOnce: initOnce,
-        conns:    conns,
+        initOnce: &initOnce,
         logger: logger,
-
-        grpcClients: grpcClients,
     }, nil
 }
 
 func (c *relayClient) GetBlob(ctx context.Context, relayKey corev2.RelayKey, blobKey corev2.BlobKey) ([]byte, error) {
-    if err := c.initOnceGrpcConnection(relayKey); err != nil {
+    client, err := c.getClient(relayKey)
+    if err != nil {
         return nil, err
     }
 
-    client, ok := c.grpcClients[relayKey]
-    if !ok {
-        return nil, fmt.Errorf("no grpc client for relay key: %v", relayKey)
-    }
-
     res, err := client.GetBlob(ctx, &relaygrpc.GetBlobRequest{
         BlobKey: blobKey[:],
     })
@@ -102,15 +97,11 @@ func (c *relayClient) GetChunksByRange(ctx context.Context, relayKey corev2.Rela
     if len(requests) == 0 {
         return nil, fmt.Errorf("no requests")
     }
-    if err := c.initOnceGrpcConnection(relayKey); err != nil {
+    client, err := c.getClient(relayKey)
+    if err != nil {
         return nil, err
     }
 
-    client, ok := c.grpcClients[relayKey]
-    if !ok {
-        return nil, fmt.Errorf("no grpc client for relay key: %v", relayKey)
-    }
-
     grpcRequests := make([]*relaygrpc.ChunkRequest, len(requests))
     for i, req := range requests {
         grpcRequests[i] = &relaygrpc.ChunkRequest{
@@ -138,13 +129,10 @@ func (c *relayClient) GetChunksByIndex(ctx context.Context, relayKey corev2.Rela
     if len(requests) == 0 {
         return nil, fmt.Errorf("no requests")
     }
-    if err := c.initOnceGrpcConnection(relayKey); err != nil {
-        return nil, err
-    }
 
-    client, ok := c.grpcClients[relayKey]
-    if !ok {
-        return nil, fmt.Errorf("no grpc client for relay key: %v", relayKey)
+    client, err := c.getClient(relayKey)
+    if err != nil {
+        return nil, err
     }
 
     grpcRequests := make([]*relaygrpc.ChunkRequest, len(requests))
@@ -169,9 +157,28 @@ func (c *relayClient) GetChunksByIndex(ctx context.Context, relayKey corev2.Rela
     return res.GetData(), nil
 }
 
+func (c *relayClient) getClient(key corev2.RelayKey) (relaygrpc.RelayClient, error) {
+    if err := c.initOnceGrpcConnection(key); err != nil {
+        return nil, err
+    }
+    maybeClient, ok := c.grpcClients.Load(key)
+    if !ok {
+        return nil, fmt.Errorf("no grpc client for relay key: %v", key)
+    }
+    client, ok := maybeClient.(relaygrpc.RelayClient)
+    if !ok {
+        return nil, fmt.Errorf("invalid grpc client for relay key: %v", key)
+    }
+    return client, nil
+}
+
 func (c *relayClient) initOnceGrpcConnection(key corev2.RelayKey) error {
     var initErr error
-    c.initOnce[key].Do(func() {
+    once, ok := c.initOnce.Load(key)
+    if !ok {
+        return fmt.Errorf("unknown relay key: %v", key)
+    }
+    once.(*sync.Once).Do(func() {
         socket, ok := c.config.Sockets[key]
         if !ok {
             initErr = fmt.Errorf("unknown relay key: %v", key)
@@ -183,24 +190,31 @@ func (c *relayClient) initOnceGrpcConnection(key corev2.RelayKey) error {
             initErr = err
             return
         }
-        c.conns[key] = conn
-        c.grpcClients[key] = relaygrpc.NewRelayClient(conn)
+        c.conns.Store(key, conn)
+        c.grpcClients.Store(key, relaygrpc.NewRelayClient(conn))
     })
     return initErr
 }
 
 func (c *relayClient) Close() error {
     var errList *multierror.Error
-    for k, conn := range c.conns {
+    c.conns.Range(func(k, v interface{}) bool {
+        conn, ok := v.(*grpc.ClientConn)
+        if !ok {
+            errList = multierror.Append(errList, fmt.Errorf("invalid connection for relay key: %v", k))
+            return true
+        }
+
         if conn != nil {
             err := conn.Close()
-            conn = nil
-            c.grpcClients[k] = nil
+            c.conns.Delete(k)
+            c.grpcClients.Delete(k)
             if err != nil {
                 c.logger.Error("failed to close connection", "err", err)
                 errList = multierror.Append(errList, err)
             }
         }
-    }
+        return true
+    })
     return errList.ErrorOrNil()
 }
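Taken together, getClient and initOnceGrpcConnection implement a once-per-key lazy dial. The sketch below isolates that pattern under assumed names (lazyConns, get) with insecure credentials for brevity; it is not the client's API. One property worth noting, shared with the pattern in general: because each sync.Once fires exactly once, a failed dial is not retried, and later callers see a "never established" style error rather than the original dial error.

package lazyconn

import (
    "fmt"
    "sync"

    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
)

type lazyConns struct {
    sockets map[string]string // key -> target address, fixed at construction
    once    sync.Map          // key -> *sync.Once, seeded at construction
    conns   sync.Map          // key -> *grpc.ClientConn, filled lazily
}

// get dials the target for key exactly once and reuses the connection afterwards.
func (l *lazyConns) get(key string) (*grpc.ClientConn, error) {
    onceAny, ok := l.once.Load(key)
    if !ok {
        return nil, fmt.Errorf("unknown key: %s", key)
    }
    var initErr error
    onceAny.(*sync.Once).Do(func() {
        conn, err := grpc.Dial(l.sockets[key], grpc.WithTransportCredentials(insecure.NewCredentials()))
        if err != nil {
            initErr = err
            return
        }
        l.conns.Store(key, conn)
    })
    if initErr != nil {
        return nil, initErr
    }
    connAny, ok := l.conns.Load(key)
    if !ok {
        // An earlier caller's dial failed; the Once has already fired, so the
        // connection will never appear and the original error is not replayed.
        return nil, fmt.Errorf("connection for %s was never established", key)
    }
    return connAny.(*grpc.ClientConn), nil
}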
diff --git a/node/config.go b/node/config.go
index 3b83885798..e67fc7894f 100644
--- a/node/config.go
+++ b/node/config.go
@@ -88,6 +88,8 @@ type Config struct {
     EthClientConfig geth.EthClientConfig
     LoggerConfig common.LoggerConfig
     EncoderConfig kzg.KzgConfig
+
+    EnableV2 bool
 }
 
 // NewConfig parses the Config from the provided flags or environment variables and
@@ -232,5 +234,6 @@ func NewConfig(ctx *cli.Context) (*Config, error) {
         BLSKeyPassword: ctx.GlobalString(flags.BlsKeyPasswordFlag.Name),
         BLSSignerTLSCertFilePath: ctx.GlobalString(flags.BLSSignerCertFileFlag.Name),
         BLSRemoteSignerEnabled: blsRemoteSignerEnabled,
+        EnableV2: ctx.GlobalBool(flags.EnableV2Flag.Name),
     }, nil
 }
diff --git a/node/flags/flags.go b/node/flags/flags.go
index a1829d7acf..5bcd95a98b 100644
--- a/node/flags/flags.go
+++ b/node/flags/flags.go
@@ -218,6 +218,12 @@ var (
         Required: false,
         EnvVar: common.PrefixEnvVar(EnvVarPrefix, "ENABLE_GNARK_BUNDLE_ENCODING"),
     }
+    EnableV2Flag = cli.BoolFlag{
+        Name: "enable-v2",
+        Usage: "Enable V2 features",
+        Required: false,
+        EnvVar: common.PrefixEnvVar(EnvVarPrefix, "ENABLE_V2"),
+    }
 
     // Test only, DO NOT USE the following flags in production
 
@@ -346,6 +352,7 @@ var optionalFlags = []cli.Flag{
     BLSRemoteSignerUrlFlag,
     BLSPublicKeyHexFlag,
     BLSSignerCertFileFlag,
+    EnableV2Flag,
 }
 
 func init() {
diff --git a/node/grpc/server_test.go b/node/grpc/server_test.go
index 27819f59e0..8953f1f1cd 100644
--- a/node/grpc/server_test.go
+++ b/node/grpc/server_test.go
@@ -83,6 +83,7 @@ func makeConfig(t *testing.T) *node.Config {
         DbPath: t.TempDir(),
         ID: opID,
         NumBatchValidators: runtime.GOMAXPROCS(0),
+        EnableV2: false,
     }
 }
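The three hunks above wire the new flag end to end: flags.go declares it, config.go copies it into Config.EnableV2, and the v1 test config pins it to false. A condensed sketch of the same wiring in a single file is below; the NODE_ENABLE_V2 environment variable name is an assumption based on the PrefixEnvVar call, not something the patch states, and this is not the node's actual main.

package main

import (
    "fmt"
    "os"

    "github.com/urfave/cli"
)

type config struct {
    EnableV2 bool
}

func main() {
    app := cli.NewApp()
    app.Flags = []cli.Flag{
        cli.BoolFlag{
            Name:   "enable-v2",
            Usage:  "Enable V2 features",
            EnvVar: "NODE_ENABLE_V2", // assumed expansion of common.PrefixEnvVar(EnvVarPrefix, "ENABLE_V2")
        },
    }
    app.Action = func(ctx *cli.Context) error {
        cfg := &config{EnableV2: ctx.Bool("enable-v2")}
        fmt.Println("v2 endpoints enabled:", cfg.EnableV2)
        return nil
    }
    if err := app.Run(os.Args); err != nil {
        fmt.Fprintln(os.Stderr, err)
        os.Exit(1)
    }
}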
diff --git a/node/grpc/server_v2.go b/node/grpc/server_v2.go
index 9fc53178f3..4f46a70d53 100644
--- a/node/grpc/server_v2.go
+++ b/node/grpc/server_v2.go
@@ -58,6 +58,13 @@ func (s *ServerV2) NodeInfo(ctx context.Context, in *pb.NodeInfoRequest) (*pb.No
 }
 
 func (s *ServerV2) StoreChunks(ctx context.Context, in *pb.StoreChunksRequest) (*pb.StoreChunksReply, error) {
+    if !s.config.EnableV2 {
+        return nil, api.NewErrorInvalidArg("v2 API is disabled")
+    }
+
+    if s.node.StoreV2 == nil {
+        return nil, api.NewErrorInternal("v2 store not initialized")
+    }
     batch, err := s.validateStoreChunksRequest(in)
     if err != nil {
         return nil, err
@@ -68,6 +75,7 @@ func (s *ServerV2) StoreChunks(ctx context.Context, in *pb.StoreChunksRequest) (
         return nil, api.NewErrorInternal(fmt.Sprintf("invalid batch header: %v", err))
     }
 
+    s.logger.Info("new StoreChunks request", "batchHeaderHash", hex.EncodeToString(batchHeaderHash[:]), "numBlobs", len(batch.BlobCertificates), "referenceBlockNumber", batch.BatchHeader.ReferenceBlockNumber)
     operatorState, err := s.node.ChainState.GetOperatorStateByOperator(ctx, uint(batch.BatchHeader.ReferenceBlockNumber), s.node.Config.ID)
     if err != nil {
         return nil, err
@@ -136,6 +144,14 @@ func (s *ServerV2) validateStoreChunksRequest(req *pb.StoreChunksRequest) (*core
 }
 
 func (s *ServerV2) GetChunks(ctx context.Context, in *pb.GetChunksRequest) (*pb.GetChunksReply, error) {
+    if !s.config.EnableV2 {
+        return nil, api.NewErrorInvalidArg("v2 API is disabled")
+    }
+
+    if s.node.StoreV2 == nil {
+        return nil, api.NewErrorInternal("v2 store not initialized")
+    }
+
     blobKey, err := corev2.BytesToBlobKey(in.GetBlobKey())
     if err != nil {
         return nil, api.NewErrorInvalidArg(fmt.Sprintf("invalid blob key: %v", err))
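Both handlers now fail fast: api.NewErrorInvalidArg when v2 is disabled, api.NewErrorInternal when the v2 store is missing. Those helpers come from the eigenda api package; the tests below rely only on them surfacing as gRPC status codes, so a minimal assumed equivalent (not necessarily the package's actual implementation) looks like this:

package api

import (
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// NewErrorInvalidArg surfaces to clients as codes.InvalidArgument — what the
// disabled-v2 and input-validation tests below expect.
func NewErrorInvalidArg(msg string) error {
    return status.Error(codes.InvalidArgument, msg)
}

// NewErrorInternal surfaces as codes.Internal, matching the download,
// storage, and validation failure tests.
func NewErrorInternal(msg string) error {
    return status.Error(codes.Internal, msg)
}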
diff --git a/node/grpc/server_v2_test.go b/node/grpc/server_v2_test.go
index b83d36e72d..6bb15870c6 100644
--- a/node/grpc/server_v2_test.go
+++ b/node/grpc/server_v2_test.go
@@ -84,8 +84,21 @@ func TestV2NodeInfoRequest(t *testing.T) {
     assert.True(t, err == nil)
 }
 
+func TestV2ServerWithoutV2(t *testing.T) {
+    config := makeConfig(t)
+    config.EnableV2 = false
+    c := newTestComponents(t, config)
+    _, err := c.server.StoreChunks(context.Background(), &pbv2.StoreChunksRequest{})
+    requireErrorStatus(t, err, codes.InvalidArgument)
+
+    _, err = c.server.GetChunks(context.Background(), &pbv2.GetChunksRequest{})
+    requireErrorStatus(t, err, codes.InvalidArgument)
+}
+
 func TestV2StoreChunksInputValidation(t *testing.T) {
-    c := newTestComponents(t, makeConfig(t))
+    config := makeConfig(t)
+    config.EnableV2 = true
+    c := newTestComponents(t, config)
     _, batch, _ := nodemock.MockBatch(t)
     batchProto, err := batch.ToProtobuf()
     require.NoError(t, err)
@@ -94,10 +107,7 @@ func TestV2StoreChunksInputValidation(t *testing.T) {
         Batch: &pbcommon.Batch{},
     }
     _, err = c.server.StoreChunks(context.Background(), req)
-    require.Error(t, err)
-    s, ok := status.FromError(err)
-    require.True(t, ok)
-    assert.Equal(t, s.Code(), codes.InvalidArgument)
+    requireErrorStatus(t, err, codes.InvalidArgument)
 
     req = &pbv2.StoreChunksRequest{
         Batch: &pbcommon.Batch{
@@ -106,10 +116,7 @@ func TestV2StoreChunksInputValidation(t *testing.T) {
         },
     }
     _, err = c.server.StoreChunks(context.Background(), req)
-    require.Error(t, err)
-    s, ok = status.FromError(err)
-    require.True(t, ok)
-    assert.Equal(t, s.Code(), codes.InvalidArgument)
+    requireErrorStatus(t, err, codes.InvalidArgument)
 
     req = &pbv2.StoreChunksRequest{
         Batch: &pbcommon.Batch{
@@ -118,14 +125,13 @@ func TestV2StoreChunksInputValidation(t *testing.T) {
         },
     }
     _, err = c.server.StoreChunks(context.Background(), req)
-    require.Error(t, err)
-    s, ok = status.FromError(err)
-    require.True(t, ok)
-    assert.Equal(t, s.Code(), codes.InvalidArgument)
+    requireErrorStatus(t, err, codes.InvalidArgument)
 }
 
 func TestV2StoreChunksSuccess(t *testing.T) {
-    c := newTestComponents(t, makeConfig(t))
+    config := makeConfig(t)
+    config.EnableV2 = true
+    c := newTestComponents(t, config)
     blobKeys, batch, bundles := nodemock.MockBatch(t)
 
     batchProto, err := batch.ToProtobuf()
@@ -176,7 +182,9 @@ func TestV2StoreChunksSuccess(t *testing.T) {
 }
 
 func TestV2StoreChunksDownloadFailure(t *testing.T) {
-    c := newTestComponents(t, makeConfig(t))
+    config := makeConfig(t)
+    config.EnableV2 = true
+    c := newTestComponents(t, config)
     _, batch, _ := nodemock.MockBatch(t)
 
     batchProto, err := batch.ToProtobuf()
@@ -191,14 +199,13 @@ func TestV2StoreChunksDownloadFailure(t *testing.T) {
         Batch: batchProto,
     })
     require.Nil(t, reply.GetSignature())
-    require.Error(t, err)
-    s, ok := status.FromError(err)
-    require.True(t, ok)
-    assert.Equal(t, s.Code(), codes.Internal)
+    requireErrorStatus(t, err, codes.Internal)
 }
 
 func TestV2StoreChunksStorageFailure(t *testing.T) {
-    c := newTestComponents(t, makeConfig(t))
+    config := makeConfig(t)
+    config.EnableV2 = true
+    c := newTestComponents(t, config)
     blobKeys, batch, bundles := nodemock.MockBatch(t)
 
     batchProto, err := batch.ToProtobuf()
@@ -238,14 +245,13 @@ func TestV2StoreChunksStorageFailure(t *testing.T) {
         Batch: batchProto,
     })
     require.Nil(t, reply.GetSignature())
-    require.Error(t, err)
-    s, ok := status.FromError(err)
-    require.True(t, ok)
-    assert.Equal(t, s.Code(), codes.Internal)
+    requireErrorStatus(t, err, codes.Internal)
 }
 
 func TestV2StoreChunksValidationFailure(t *testing.T) {
-    c := newTestComponents(t, makeConfig(t))
+    config := makeConfig(t)
+    config.EnableV2 = true
+    c := newTestComponents(t, config)
     blobKeys, batch, bundles := nodemock.MockBatch(t)
 
     batchProto, err := batch.ToProtobuf()
@@ -286,25 +292,21 @@ func TestV2StoreChunksValidationFailure(t *testing.T) {
         Batch: batchProto,
     })
     require.Nil(t, reply.GetSignature())
-    require.Error(t, err)
-    s, ok := status.FromError(err)
-    require.True(t, ok)
-    assert.Equal(t, s.Code(), codes.Internal)
+    requireErrorStatus(t, err, codes.Internal)
 
     c.store.AssertCalled(t, "DeleteKeys", mock.Anything, mock.Anything)
 }
 
 func TestV2GetChunksInputValidation(t *testing.T) {
-    c := newTestComponents(t, makeConfig(t))
+    config := makeConfig(t)
+    config.EnableV2 = true
+    c := newTestComponents(t, config)
     ctx := context.Background()
     req := &pbv2.GetChunksRequest{
         BlobKey: []byte{0},
     }
     _, err := c.server.GetChunks(ctx, req)
-    require.Error(t, err)
-    s, ok := status.FromError(err)
-    require.True(t, ok)
-    assert.Equal(t, s.Code(), codes.InvalidArgument)
+    requireErrorStatus(t, err, codes.InvalidArgument)
 
     bk := [32]byte{0}
     maxUInt32 := uint32(0xFFFFFFFF)
@@ -313,10 +315,14 @@ func TestV2GetChunksInputValidation(t *testing.T) {
         QuorumId: maxUInt32,
     }
     _, err = c.server.GetChunks(ctx, req)
+    requireErrorStatus(t, err, codes.InvalidArgument)
+}
+
+func requireErrorStatus(t *testing.T, err error, code codes.Code) {
     require.Error(t, err)
-    s, ok = status.FromError(err)
+    s, ok := status.FromError(err)
     require.True(t, ok)
-    assert.Equal(t, s.Code(), codes.InvalidArgument)
+    assert.Equal(t, s.Code(), code)
 }
 
 type mockKey struct{}
diff --git a/node/node.go b/node/node.go
index a54b76ac33..d22058c7a5 100644
--- a/node/node.go
+++ b/node/node.go
@@ -16,6 +16,7 @@ import (
     "sync"
     "time"
 
+    "github.com/Layr-Labs/eigenda/common/kvstore/tablestore"
     "github.com/Layr-Labs/eigenda/common/pubip"
 
     "github.com/Layr-Labs/eigenda/encoding/kzg/verifier"
@@ -216,7 +217,25 @@ func NewNode(
         "eigenDAServiceManagerAddr", config.EigenDAServiceManagerAddr, "blockStaleMeasure", blockStaleMeasure,
         "storeDurationBlocks", storeDurationBlocks, "enableGnarkBundleEncoding", config.EnableGnarkBundleEncoding)
 
     var relayClient clients.RelayClient
-    // Create a new relay client with relay addresses onchain
+    var storeV2 StoreV2
+    if config.EnableV2 {
+        v2Path := config.DbPath + "/chunk_v2"
+        dbV2, err := tablestore.Start(logger, &tablestore.Config{
+            Type: tablestore.LevelDB,
+            Path: &v2Path,
+            GarbageCollectionEnabled: true,
+            GarbageCollectionInterval: time.Duration(config.ExpirationPollIntervalSec) * time.Second,
+            GarbageCollectionBatchSize: 1024,
+            Schema: []string{BatchHeaderTableName, BlobCertificateTableName, BundleTableName},
+        })
+        if err != nil {
+            return nil, fmt.Errorf("failed to create new tablestore: %w", err)
+        }
+        storeV2 = NewLevelDBStoreV2(dbV2, logger)
+
+        // TODO(ian-shim): Create a new relay client with relay addresses onchain
+    }
+
     return &Node{
         Config: config,
         Logger: nodeLogger,
@@ -224,7 +243,7 @@ func NewNode(
         Metrics: metrics,
         NodeApi: nodeApi,
         Store: store,
-        StoreV2: nil,
+        StoreV2: storeV2,
         ChainState: cst,
         Transactor: tx,
         Validator: validator,
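The v2 store is only constructed when EnableV2 is set, which is what keeps the nil-StoreV2 guard in server_v2.go from firing on a properly configured node. Below is the same construction factored into a helper as a sketch, with one editorial tweak — filepath.Join instead of string concatenation, which tolerates a trailing slash in DbPath. It assumes it would live in package node next to NewNode so that Config, StoreV2, and the table-name constants resolve; it is not a drop-in part of the patch.

package node

import (
    "fmt"
    "path/filepath"
    "time"

    "github.com/Layr-Labs/eigenda/common/kvstore/tablestore"
    "github.com/Layr-Labs/eigensdk-go/logging"
)

// openV2Store mirrors the construction in NewNode: a LevelDB-backed table
// store under <DbPath>/chunk_v2 with garbage collection driven by the same
// ExpirationPollIntervalSec used elsewhere in the node config.
func openV2Store(logger logging.Logger, config *Config) (StoreV2, error) {
    v2Path := filepath.Join(config.DbPath, "chunk_v2")
    dbV2, err := tablestore.Start(logger, &tablestore.Config{
        Type: tablestore.LevelDB,
        Path: &v2Path,
        GarbageCollectionEnabled: true,
        GarbageCollectionInterval: time.Duration(config.ExpirationPollIntervalSec) * time.Second,
        GarbageCollectionBatchSize: 1024,
        Schema: []string{BatchHeaderTableName, BlobCertificateTableName, BundleTableName},
    })
    if err != nil {
        return nil, fmt.Errorf("failed to create new tablestore: %w", err)
    }
    return NewLevelDBStoreV2(dbV2, logger), nil
}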
diff --git a/node/store_v2.go b/node/store_v2.go
index c5979e72fe..62da00f54d 100644
--- a/node/store_v2.go
+++ b/node/store_v2.go
@@ -41,6 +41,13 @@ func NewLevelDBStoreV2(db kvstore.TableStore, logger logging.Logger) *storeV2 {
 }
 
 func (s *storeV2) StoreBatch(batch *corev2.Batch, rawBundles []*RawBundles) ([]kvstore.Key, error) {
+    if len(rawBundles) == 0 {
+        return nil, fmt.Errorf("no raw bundles")
+    }
+    if len(rawBundles) != len(batch.BlobCertificates) {
+        return nil, fmt.Errorf("mismatch between raw bundles (%d) and blob certificates (%d)", len(rawBundles), len(batch.BlobCertificates))
+    }
+
     dbBatch := s.db.NewTTLBatch()
 
     keys := make([]kvstore.Key, 0)
diff --git a/relay/metadata_provider.go b/relay/metadata_provider.go
index 3e32924072..33407fa124 100644
--- a/relay/metadata_provider.go
+++ b/relay/metadata_provider.go
@@ -3,12 +3,13 @@ package relay
 import (
     "context"
     "fmt"
-    "github.com/Layr-Labs/eigenda/core/v2"
+    "sync/atomic"
+
+    v2 "github.com/Layr-Labs/eigenda/core/v2"
     "github.com/Layr-Labs/eigenda/disperser/common/v2/blobstore"
     "github.com/Layr-Labs/eigenda/encoding"
     "github.com/Layr-Labs/eigenda/relay/cache"
     "github.com/Layr-Labs/eigensdk-go/logging"
-    "sync/atomic"
 )
 
 // Metadata about a blob. The relay only needs a small subset of a blob's metadata.
@@ -79,8 +80,10 @@ func newMetadataProvider(
 type metadataMap map[v2.BlobKey]*blobMetadata
 
 // GetMetadataForBlobs retrieves metadata about multiple blobs in parallel.
+// If any of the blobs do not exist, an error is returned.
+// Note that resulting metadata map may not have the same length as the input
+// keys slice if the input keys slice has duplicate items.
 func (m *metadataProvider) GetMetadataForBlobs(keys []v2.BlobKey) (metadataMap, error) {
-
     // blobMetadataResult is the result of a metadata fetch operation.
     type blobMetadataResult struct {
         key v2.BlobKey
@@ -94,7 +97,12 @@ func (m *metadataProvider) GetMetadataForBlobs(keys []v2.BlobKey) (metadataMap,
     // Set when the first error is encountered. Useful for preventing new operations from starting.
     hadError := atomic.Bool{}
 
+    mMap := make(metadataMap)
     for _, key := range keys {
+        mMap[key] = nil
+    }
+
+    for key := range mMap {
         if hadError.Load() {
             // Don't bother starting new operations if we've already encountered an error.
             break
@@ -122,8 +130,7 @@ func (m *metadataProvider) GetMetadataForBlobs(keys []v2.BlobKey) (metadataMap,
         }()
     }
 
-    mMap := make(metadataMap)
-    for len(mMap) < len(keys) {
+    for range mMap {
         result := <-completionChannel
         if result.err != nil {
             return nil, fmt.Errorf("error fetching metadata for blob %s: %w", result.key.Hex(), result.err)
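GetMetadataForBlobs now folds the requested keys into a map before fanning out, so duplicate keys are fetched once and the returned map can be smaller than the input slice (the new doc comment above calls this out). A self-contained sketch of that dedup-then-fan-out shape, with illustrative types rather than the provider's actual API:

package dedupfetch

// fetchAll mirrors the pattern in GetMetadataForBlobs: duplicate keys collapse
// into a single map entry, each unique key is fetched once, and the result
// map's length tracks unique keys, not len(keys).
func fetchAll(keys []string, fetch func(string) (string, error)) (map[string]string, error) {
    // Deduplicate up front; the zero value is a placeholder until fetched.
    unique := make(map[string]string, len(keys))
    for _, k := range keys {
        unique[k] = ""
    }

    type result struct {
        key string
        val string
        err error
    }
    // Buffered so late goroutines never block if we return early on error.
    results := make(chan result, len(unique))
    for k := range unique {
        k := k
        go func() {
            v, err := fetch(k)
            results <- result{key: k, val: v, err: err}
        }()
    }

    // One receive per unique key, not per requested key.
    for range unique {
        r := <-results
        if r.err != nil {
            return nil, r.err
        }
        unique[r.key] = r.val
    }
    return unique, nil
}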
diff --git a/relay/metadata_provider_test.go b/relay/metadata_provider_test.go
index e5586d901b..32e5a3e80c 100644
--- a/relay/metadata_provider_test.go
+++ b/relay/metadata_provider_test.go
@@ -2,14 +2,15 @@ package relay
 
 import (
     "context"
+    "math/rand"
+    "testing"
+
     "github.com/Layr-Labs/eigenda/common"
     tu "github.com/Layr-Labs/eigenda/common/testutils"
-    "github.com/Layr-Labs/eigenda/core/v2"
+    v2 "github.com/Layr-Labs/eigenda/core/v2"
     "github.com/Layr-Labs/eigenda/encoding"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
-    "math/rand"
-    "testing"
 )
 
 func TestGetNonExistentBlob(t *testing.T) {
@@ -121,10 +122,12 @@ func TestBatchedFetch(t *testing.T) {
 
     // Write some metadata
     blobCount := 10
+    blobKeys := make([]v2.BlobKey, blobCount)
     for i := 0; i < blobCount; i++ {
         header, _ := randomBlob(t)
         blobKey, err := header.BlobKey()
         require.NoError(t, err)
+        blobKeys[i] = blobKey
 
         totalChunkSizeBytes := uint32(rand.Intn(1024 * 1024 * 1024))
         fragmentSizeBytes := uint32(rand.Intn(1024 * 1024))
@@ -179,6 +182,11 @@ func TestBatchedFetch(t *testing.T) {
             require.Equal(t, fragmentSizeMap[key], metadata.fragmentSizeBytes)
         }
     }
+
+    // Test fetching with duplicate keys
+    mMap, err := server.GetMetadataForBlobs([]v2.BlobKey{blobKeys[0], blobKeys[0]})
+    require.NoError(t, err)
+    require.Equal(t, 1, len(mMap))
 }
 
 func TestIndividualFetchWithSharding(t *testing.T) {
diff --git a/relay/server.go b/relay/server.go
index ad6072b9fe..56bedea146 100644
--- a/relay/server.go
+++ b/relay/server.go
@@ -4,6 +4,9 @@ import (
     "context"
     "errors"
     "fmt"
+    "net"
+    "time"
+
     pb "github.com/Layr-Labs/eigenda/api/grpc/relay"
     "github.com/Layr-Labs/eigenda/common/healthcheck"
     "github.com/Layr-Labs/eigenda/core"
@@ -15,8 +18,6 @@ import (
     "github.com/Layr-Labs/eigensdk-go/logging"
     "google.golang.org/grpc"
     "google.golang.org/grpc/reflection"
-    "net"
-    "time"
 )
 
 var _ pb.RelayServer = &Server{}
@@ -208,6 +209,7 @@ func (s *Server) GetChunks(ctx context.Context, request *pb.GetChunksRequest) (*
     }
     defer s.chunkRateLimiter.FinishGetChunkOperation(clientID)
 
+    // keys might contain duplicate keys
     keys, err := getKeysFromChunkRequest(request)
     if err != nil {
         return nil, err
@@ -273,7 +275,7 @@ func gatherChunkDataToSend(
     frames map[v2.BlobKey][]*encoding.Frame,
     request *pb.GetChunksRequest) ([][]byte, error) {
 
-    bytesToSend := make([][]byte, 0, len(frames))
+    bytesToSend := make([][]byte, 0, len(request.ChunkRequests))
 
     for _, chunkRequest := range request.ChunkRequests {
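The one-line change in gatherChunkDataToSend matters for duplicate requests: frames is keyed by blob key, so duplicate requests for the same blob collapse into a single entry, but the reply must still contain one element per chunk request. A small sketch of that shape with illustrative types (the real code deserializes and re-bundles frames per request):

package chunkreply

// buildResponses shows why the capacity is len(requestedKeys) — one entry per
// request — rather than len(frames), which only counts unique blobs. Duplicate
// requests share a frames entry but still get their own reply element.
func buildResponses(requestedKeys []string, frames map[string][]byte) [][]byte {
    out := make([][]byte, 0, len(requestedKeys))
    for _, key := range requestedKeys {
        out = append(out, frames[key])
    }
    return out
}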
diff --git a/relay/server_test.go b/relay/server_test.go
index cedfa6ddb4..10a7ecca91 100644
--- a/relay/server_test.go
+++ b/relay/server_test.go
@@ -2,10 +2,11 @@ package relay
 
 import (
     "context"
-    "github.com/Layr-Labs/eigenda/relay/limiter"
     "math/rand"
     "testing"
 
+    "github.com/Layr-Labs/eigenda/relay/limiter"
+
     pb "github.com/Layr-Labs/eigenda/api/grpc/relay"
     "github.com/Layr-Labs/eigenda/common"
     tu "github.com/Layr-Labs/eigenda/common/testutils"
@@ -1015,6 +1016,16 @@ func TestBatchedReadWriteChunksWithSharding(t *testing.T) {
         requestedChunks = append(requestedChunks, request)
     }
 
+    // Add a request for duplicate key with different index range
+    requestedChunks = append(requestedChunks, &pb.ChunkRequest{
+        Request: &pb.ChunkRequest_ByRange{
+            ByRange: &pb.ChunkRequestByRange{
+                BlobKey: keys[0][:],
+                StartIndex: uint32(len(expectedData[keys[0]]) / 2),
+                EndIndex: uint32(len(expectedData[keys[0]])),
+            },
+        },
+    })
     request := &pb.GetChunksRequest{
         ChunkRequests: requestedChunks,
     }
@@ -1036,11 +1047,10 @@ func TestBatchedReadWriteChunksWithSharding(t *testing.T) {
     }
 
     response, err := getChunks(t, request)
-
     if allInCorrectShard {
         require.NoError(t, err)
-        require.Equal(t, keyCount, len(response.Data))
+        require.Equal(t, keyCount+1, len(response.Data))
 
         for keyIndex, key := range keys {
             data := expectedData[key]
@@ -1052,6 +1062,17 @@ func TestBatchedReadWriteChunksWithSharding(t *testing.T) {
                 require.Equal(t, data[frameIndex], frame)
             }
         }
+
+        // Check the duplicate key
+        key := keys[0]
+        data := expectedData[key][len(expectedData[key])/2:]
+
+        bundle, err := core.Bundle{}.Deserialize(response.Data[keyCount])
+        require.NoError(t, err)
+
+        for frameIndex, frame := range bundle {
+            require.Equal(t, data[frameIndex], frame)
+        }
     } else {
         require.Error(t, err)
         require.Nil(t, response)