diff --git a/getlatest.go b/getlatest.go index 935e959..673f357 100644 --- a/getlatest.go +++ b/getlatest.go @@ -1,8 +1,8 @@ package namesys import ( - "bufio" "context" + "errors" "io" "time" @@ -17,27 +17,33 @@ import ( pb "github.com/libp2p/go-libp2p-pubsub-router/pb" ) +var GetLatestErr = errors.New("get-latest: received error") + type getLatestProtocol struct { + ctx context.Context host host.Host } -func newGetLatestProtocol(host host.Host, getLocal func(key string) ([]byte, error)) *getLatestProtocol { - p := &getLatestProtocol{host} +func newGetLatestProtocol(ctx context.Context, host host.Host, getLocal func(key string) ([]byte, error)) *getLatestProtocol { + p := &getLatestProtocol{ctx, host} host.SetStreamHandler(PSGetLatestProto, func(s network.Stream) { - p.Receive(s, getLocal) + p.receive(s, getLocal) }) return p } -func (p *getLatestProtocol) Receive(s network.Stream, getLocal func(key string) ([]byte, error)) { - r := ggio.NewDelimitedReader(s, 1<<20) +func (p *getLatestProtocol) receive(s network.Stream, getLocal func(key string) ([]byte, error)) { msg := &pb.RequestLatest{} - if err := r.ReadMsg(msg); err != nil { + if err := readMsg(p.ctx, s, msg); err != nil { if err != io.EOF { - s.Reset() log.Infof("error reading request from %s: %s", s.Conn().RemotePeer(), err) + respProto := pb.RespondLatest{Status: pb.RespondLatest_ERR} + if err := writeMsg(p.ctx, s, &respProto); err != nil { + return + } + helpers.FullClose(s) } else { // Just be nice. They probably won't read this // but it doesn't hurt to send it. @@ -46,25 +52,22 @@ func (p *getLatestProtocol) Receive(s network.Stream, getLocal func(key string) return } - response, err := getLocal(*msg.Identifier) + response, err := getLocal(msg.Identifier) var respProto pb.RespondLatest - if err != nil || response == nil { - nodata := true - respProto = pb.RespondLatest{Nodata: &nodata} + if err != nil { + respProto = pb.RespondLatest{Status: pb.RespondLatest_NOT_FOUND} } else { respProto = pb.RespondLatest{Data: response} } - if err := writeBytes(s, &respProto); err != nil { - s.Reset() - log.Infof("error writing response to %s: %s", s.Conn().RemotePeer(), err) + if err := writeMsg(p.ctx, s, &respProto); err != nil { return } helpers.FullClose(s) } -func (p getLatestProtocol) Send(ctx context.Context, pid peer.ID, key string) ([]byte, error) { +func (p getLatestProtocol) Get(ctx context.Context, pid peer.ID, key string) ([]byte, error) { peerCtx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() @@ -72,38 +75,75 @@ func (p getLatestProtocol) Send(ctx context.Context, pid peer.ID, key string) ([ if err != nil { return nil, err } - - if err := s.SetDeadline(time.Now().Add(time.Second * 5)); err != nil { - return nil, err - } - defer helpers.FullClose(s) - msg := pb.RequestLatest{Identifier: &key} + msg := &pb.RequestLatest{Identifier: key} - if err := writeBytes(s, &msg); err != nil { - s.Reset() + if err := writeMsg(ctx, s, msg); err != nil { return nil, err } - s.Close() - r := ggio.NewDelimitedReader(s, 1<<20) response := &pb.RespondLatest{} - if err := r.ReadMsg(response); err != nil { + if err := readMsg(ctx, s, response); err != nil { return nil, err } - return response.Data, nil + switch response.Status { + case pb.RespondLatest_SUCCESS: + return response.Data, nil + case pb.RespondLatest_NOT_FOUND: + return nil, nil + case pb.RespondLatest_ERR: + return nil, GetLatestErr + default: + return nil, errors.New("get-latest: received unknown status code") + } } -func writeBytes(w io.Writer, msg proto.Message) error { - bufw := bufio.NewWriter(w) - wc := ggio.NewDelimitedWriter(bufw) +func writeMsg(ctx context.Context, s network.Stream, msg proto.Message) error { + done := make(chan error) + go func() { + wc := ggio.NewDelimitedWriter(s) - if err := wc.WriteMsg(msg); err != nil { - return err + if err := wc.WriteMsg(msg); err != nil { + done <- err + return + } + + done <- nil + }() + + var retErr error + select { + case retErr = <-done: + case <-ctx.Done(): + retErr = ctx.Err() + } + + if retErr != nil { + s.Reset() + log.Infof("error writing response to %s: %s", s.Conn().RemotePeer(), retErr) } + return retErr +} - return bufw.Flush() +func readMsg(ctx context.Context, s network.Stream, msg proto.Message) error { + done := make(chan error) + go func() { + r := ggio.NewDelimitedReader(s, 1<<20) + if err := r.ReadMsg(msg); err != nil { + done <- err + return + } + done <- nil + }() + + select { + case err := <-done: + return err + case <-ctx.Done(): + s.Reset() + return ctx.Err() + } } diff --git a/getlatest_test.go b/getlatest_test.go index ea80c2b..41fe589 100644 --- a/getlatest_test.go +++ b/getlatest_test.go @@ -3,11 +3,15 @@ package namesys import ( "bytes" "context" + "encoding/binary" "errors" "testing" "time" + "github.com/libp2p/go-libp2p-core/helpers" "github.com/libp2p/go-libp2p-core/host" + + pb "github.com/libp2p/go-libp2p-pubsub-router/pb" ) func connect(t *testing.T, a, b host.Host) { @@ -41,16 +45,16 @@ func TestGetLatestProtocolTrip(t *testing.T) { time.Sleep(time.Millisecond * 100) d1 := &datastore{map[string][]byte{"key": []byte("value1")}} - h1 := newGetLatestProtocol(hosts[0], d1.Lookup) + h1 := newGetLatestProtocol(ctx, hosts[0], d1.Lookup) d2 := &datastore{map[string][]byte{"key": []byte("value2")}} - h2 := newGetLatestProtocol(hosts[1], d2.Lookup) + h2 := newGetLatestProtocol(ctx, hosts[1], d2.Lookup) getLatest(t, ctx, h1, h2, "key", []byte("value2")) getLatest(t, ctx, h2, h1, "key", []byte("value1")) } -func TestGetLatestProtocolNil(t *testing.T) { +func TestGetLatestProtocolNotFound(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -61,15 +65,51 @@ func TestGetLatestProtocolNil(t *testing.T) { time.Sleep(time.Millisecond * 100) d1 := &datastore{map[string][]byte{"key": []byte("value1")}} - h1 := newGetLatestProtocol(hosts[0], d1.Lookup) + h1 := newGetLatestProtocol(ctx, hosts[0], d1.Lookup) d2 := &datastore{make(map[string][]byte)} - h2 := newGetLatestProtocol(hosts[1], d2.Lookup) + h2 := newGetLatestProtocol(ctx, hosts[1], d2.Lookup) getLatest(t, ctx, h1, h2, "key", nil) getLatest(t, ctx, h2, h1, "key", []byte("value1")) } +func TestGetLatestProtocolErr(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := newNetHosts(ctx, t, 2) + connect(t, hosts[0], hosts[1]) + + // wait for hosts to get connected + time.Sleep(time.Millisecond * 100) + + d1 := &datastore{make(map[string][]byte)} + h1 := newGetLatestProtocol(ctx, hosts[0], d1.Lookup) + + // bad send protocol to force an error + s, err := hosts[1].NewStream(ctx, h1.host.ID(), PSGetLatestProto) + if err != nil { + t.Fatal(err) + } + defer helpers.FullClose(s) + + buf := make([]byte, binary.MaxVarintLen64) + binary.PutUvarint(buf, ^uint64(0)) + if _, err := s.Write(buf); err != nil { + t.Fatal(err) + } + + response := &pb.RespondLatest{} + if err := readMsg(ctx, s, response); err != nil { + t.Fatal(err) + } + + if response.Status != pb.RespondLatest_ERR { + t.Fatal("should have received an error") + } +} + func TestGetLatestProtocolRepeated(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -81,10 +121,10 @@ func TestGetLatestProtocolRepeated(t *testing.T) { time.Sleep(time.Millisecond * 100) d1 := &datastore{map[string][]byte{"key": []byte("value1")}} - h1 := newGetLatestProtocol(hosts[0], d1.Lookup) + h1 := newGetLatestProtocol(ctx, hosts[0], d1.Lookup) d2 := &datastore{make(map[string][]byte)} - h2 := newGetLatestProtocol(hosts[1], d2.Lookup) + h2 := newGetLatestProtocol(ctx, hosts[1], d2.Lookup) for i := 0; i < 10; i++ { getLatest(t, ctx, h1, h2, "key", nil) @@ -94,7 +134,7 @@ func TestGetLatestProtocolRepeated(t *testing.T) { func getLatest(t *testing.T, ctx context.Context, requester *getLatestProtocol, responder *getLatestProtocol, key string, expected []byte) { - data, err := requester.Send(ctx, responder.host.ID(), key) + data, err := requester.Get(ctx, responder.host.ID(), key) if err != nil { t.Fatal(err) } diff --git a/pb/Makefile b/pb/Makefile index eb14b57..0353d62 100644 --- a/pb/Makefile +++ b/pb/Makefile @@ -1,10 +1,16 @@ PB = $(wildcard *.proto) GO = $(PB:.proto=.pb.go) +ifeq ($(OS),Windows_NT) + GOPATH_DELIMITER = \; +else + GOPATH_DELIMITER = : +endif + all: $(GO) %.pb.go: %.proto - protoc --proto_path=$(GOPATH)/src:. --gogofast_out=. $< + protoc --proto_path=$(GOPATH)/src$(GOPATH_DELIMITER). --gogofast_out=. $< clean: rm -f *.pb.go diff --git a/pb/message.pb.go b/pb/message.pb.go index 91e88fe..f728d34 100644 --- a/pb/message.pb.go +++ b/pb/message.pb.go @@ -21,8 +21,36 @@ var _ = math.Inf // proto package needs to be updated. const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package +type RespondLatest_StatusCode int32 + +const ( + RespondLatest_SUCCESS RespondLatest_StatusCode = 0 + RespondLatest_NOT_FOUND RespondLatest_StatusCode = 1 + RespondLatest_ERR RespondLatest_StatusCode = 2 +) + +var RespondLatest_StatusCode_name = map[int32]string{ + 0: "SUCCESS", + 1: "NOT_FOUND", + 2: "ERR", +} + +var RespondLatest_StatusCode_value = map[string]int32{ + "SUCCESS": 0, + "NOT_FOUND": 1, + "ERR": 2, +} + +func (x RespondLatest_StatusCode) String() string { + return proto.EnumName(RespondLatest_StatusCode_name, int32(x)) +} + +func (RespondLatest_StatusCode) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_33c57e4bae7b9afd, []int{1, 0} +} + type RequestLatest struct { - Identifier *string `protobuf:"bytes,1,opt,name=identifier" json:"identifier,omitempty"` + Identifier string `protobuf:"bytes,1,opt,name=identifier,proto3" json:"identifier,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` @@ -62,18 +90,18 @@ func (m *RequestLatest) XXX_DiscardUnknown() { var xxx_messageInfo_RequestLatest proto.InternalMessageInfo func (m *RequestLatest) GetIdentifier() string { - if m != nil && m.Identifier != nil { - return *m.Identifier + if m != nil { + return m.Identifier } return "" } type RespondLatest struct { - Data []byte `protobuf:"bytes,1,opt,name=data" json:"data,omitempty"` - Nodata *bool `protobuf:"varint,2,opt,name=nodata" json:"nodata,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + Status RespondLatest_StatusCode `protobuf:"varint,2,opt,name=status,proto3,enum=namesys.pb.RespondLatest_StatusCode" json:"status,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *RespondLatest) Reset() { *m = RespondLatest{} } @@ -116,14 +144,15 @@ func (m *RespondLatest) GetData() []byte { return nil } -func (m *RespondLatest) GetNodata() bool { - if m != nil && m.Nodata != nil { - return *m.Nodata +func (m *RespondLatest) GetStatus() RespondLatest_StatusCode { + if m != nil { + return m.Status } - return false + return RespondLatest_SUCCESS } func init() { + proto.RegisterEnum("namesys.pb.RespondLatest_StatusCode", RespondLatest_StatusCode_name, RespondLatest_StatusCode_value) proto.RegisterType((*RequestLatest)(nil), "namesys.pb.RequestLatest") proto.RegisterType((*RespondLatest)(nil), "namesys.pb.RespondLatest") } @@ -131,17 +160,21 @@ func init() { func init() { proto.RegisterFile("message.proto", fileDescriptor_33c57e4bae7b9afd) } var fileDescriptor_33c57e4bae7b9afd = []byte{ - // 145 bytes of a gzipped FileDescriptorProto + // 214 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcd, 0x4d, 0x2d, 0x2e, 0x4e, 0x4c, 0x4f, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0xca, 0x4b, 0xcc, 0x4d, 0x2d, 0xae, 0x2c, 0xd6, 0x2b, 0x48, 0x52, 0xd2, 0xe7, 0xe2, 0x0d, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0xf1, 0x49, 0x2c, 0x49, 0x2d, 0x2e, 0x11, 0x92, 0xe3, 0xe2, 0xca, 0x4c, 0x49, 0xcd, 0x2b, 0xc9, - 0x4c, 0xcb, 0x4c, 0x2d, 0x92, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x42, 0x12, 0x51, 0xb2, 0x06, - 0x69, 0x28, 0x2e, 0xc8, 0xcf, 0x4b, 0x81, 0x6a, 0x10, 0xe2, 0x62, 0x49, 0x49, 0x2c, 0x49, 0x04, - 0x2b, 0xe5, 0x09, 0x02, 0xb3, 0x85, 0xc4, 0xb8, 0xd8, 0xf2, 0xf2, 0xc1, 0xa2, 0x4c, 0x0a, 0x8c, - 0x1a, 0x1c, 0x41, 0x50, 0x9e, 0x13, 0xcf, 0x89, 0x47, 0x72, 0x8c, 0x17, 0x1e, 0xc9, 0x31, 0x3e, - 0x78, 0x24, 0xc7, 0x08, 0x08, 0x00, 0x00, 0xff, 0xff, 0x21, 0x3f, 0xdb, 0xfb, 0x97, 0x00, 0x00, - 0x00, + 0x4c, 0xcb, 0x4c, 0x2d, 0x92, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x42, 0x12, 0x51, 0x9a, 0xc2, + 0x08, 0xd2, 0x51, 0x5c, 0x90, 0x9f, 0x97, 0x02, 0xd5, 0x21, 0xc4, 0xc5, 0x92, 0x92, 0x58, 0x92, + 0x08, 0x56, 0xcb, 0x13, 0x04, 0x66, 0x0b, 0xd9, 0x70, 0xb1, 0x15, 0x97, 0x24, 0x96, 0x94, 0x16, + 0x4b, 0x30, 0x29, 0x30, 0x6a, 0xf0, 0x19, 0xa9, 0xe8, 0x21, 0xec, 0xd4, 0x43, 0xd1, 0xae, 0x17, + 0x0c, 0x56, 0xe7, 0x9c, 0x9f, 0x92, 0x1a, 0x04, 0xd5, 0xa3, 0x64, 0xc8, 0xc5, 0x85, 0x10, 0x15, + 0xe2, 0xe6, 0x62, 0x0f, 0x0e, 0x75, 0x76, 0x76, 0x0d, 0x0e, 0x16, 0x60, 0x10, 0xe2, 0xe5, 0xe2, + 0xf4, 0xf3, 0x0f, 0x89, 0x77, 0xf3, 0x0f, 0xf5, 0x73, 0x11, 0x60, 0x14, 0x62, 0xe7, 0x62, 0x76, + 0x0d, 0x0a, 0x12, 0x60, 0x72, 0xe2, 0x39, 0xf1, 0x48, 0x8e, 0xf1, 0xc2, 0x23, 0x39, 0xc6, 0x07, + 0x8f, 0xe4, 0x18, 0x93, 0xd8, 0xc0, 0x1e, 0x35, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x0c, 0x22, + 0x78, 0xb9, 0xf9, 0x00, 0x00, 0x00, } func (m *RequestLatest) Marshal() (dAtA []byte, err error) { @@ -159,11 +192,11 @@ func (m *RequestLatest) MarshalTo(dAtA []byte) (int, error) { _ = i var l int _ = l - if m.Identifier != nil { + if len(m.Identifier) > 0 { dAtA[i] = 0xa i++ - i = encodeVarintMessage(dAtA, i, uint64(len(*m.Identifier))) - i += copy(dAtA[i:], *m.Identifier) + i = encodeVarintMessage(dAtA, i, uint64(len(m.Identifier))) + i += copy(dAtA[i:], m.Identifier) } if m.XXX_unrecognized != nil { i += copy(dAtA[i:], m.XXX_unrecognized) @@ -186,21 +219,16 @@ func (m *RespondLatest) MarshalTo(dAtA []byte) (int, error) { _ = i var l int _ = l - if m.Data != nil { + if len(m.Data) > 0 { dAtA[i] = 0xa i++ i = encodeVarintMessage(dAtA, i, uint64(len(m.Data))) i += copy(dAtA[i:], m.Data) } - if m.Nodata != nil { + if m.Status != 0 { dAtA[i] = 0x10 i++ - if *m.Nodata { - dAtA[i] = 1 - } else { - dAtA[i] = 0 - } - i++ + i = encodeVarintMessage(dAtA, i, uint64(m.Status)) } if m.XXX_unrecognized != nil { i += copy(dAtA[i:], m.XXX_unrecognized) @@ -223,8 +251,8 @@ func (m *RequestLatest) Size() (n int) { } var l int _ = l - if m.Identifier != nil { - l = len(*m.Identifier) + l = len(m.Identifier) + if l > 0 { n += 1 + l + sovMessage(uint64(l)) } if m.XXX_unrecognized != nil { @@ -239,12 +267,12 @@ func (m *RespondLatest) Size() (n int) { } var l int _ = l - if m.Data != nil { - l = len(m.Data) + l = len(m.Data) + if l > 0 { n += 1 + l + sovMessage(uint64(l)) } - if m.Nodata != nil { - n += 2 + if m.Status != 0 { + n += 1 + sovMessage(uint64(m.Status)) } if m.XXX_unrecognized != nil { n += len(m.XXX_unrecognized) @@ -324,8 +352,7 @@ func (m *RequestLatest) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - s := string(dAtA[iNdEx:postIndex]) - m.Identifier = &s + m.Identifier = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex default: iNdEx = preIndex @@ -417,9 +444,9 @@ func (m *RespondLatest) Unmarshal(dAtA []byte) error { iNdEx = postIndex case 2: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Nodata", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Status", wireType) } - var v int + m.Status = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowMessage @@ -429,13 +456,11 @@ func (m *RespondLatest) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - v |= int(b&0x7F) << shift + m.Status |= RespondLatest_StatusCode(b&0x7F) << shift if b < 0x80 { break } } - b := bool(v != 0) - m.Nodata = &b default: iNdEx = preIndex skippy, err := skipMessage(dAtA[iNdEx:]) diff --git a/pb/message.proto b/pb/message.proto index 4a04228..1ec433d 100644 --- a/pb/message.proto +++ b/pb/message.proto @@ -1,12 +1,17 @@ -syntax = "proto2"; +syntax = "proto3"; package namesys.pb; message RequestLatest { - optional string identifier = 1; + string identifier = 1; } message RespondLatest { - optional bytes data = 1; - optional bool nodata = 2; + bytes data = 1; + StatusCode status = 2; + enum StatusCode { + SUCCESS = 0; + NOT_FOUND = 1; + ERR = 2; + } } \ No newline at end of file diff --git a/pubsub.go b/pubsub.go index 79571ac..c9c737a 100644 --- a/pubsub.go +++ b/pubsub.go @@ -89,7 +89,7 @@ func NewPubsubValueStore(ctx context.Context, host host.Host, cr routing.Content Validator: validator, } - psValueStore.getLatest = newGetLatestProtocol(host, psValueStore.getLocal) + psValueStore.getLatest = newGetLatestProtocol(ctx, host, psValueStore.getLocal) go psValueStore.rebroadcast(ctx) @@ -476,7 +476,7 @@ func (p *PubsubValueStore) handleNewPeer(ctx context.Context, sub *pubsub.Subscr } } - return p.getLatest.Send(ctx, pid, key) + return p.getLatest.Get(ctx, pid, key) } func (p *PubsubValueStore) notifyWatchers(key string, data []byte) {