Skip to content

Commit

Permalink
Changed get-latest protocol to have responses with a status code.
Browse files Browse the repository at this point in the history
Wrapped protocol read/writes in contexts.
A little refactoring.
Protobuf Makefile builds on Windows too.
  • Loading branch information
aschmahmann committed Aug 9, 2019
1 parent e9b0864 commit f85f2bc
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 92 deletions.
108 changes: 74 additions & 34 deletions getlatest.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package namesys

import (
"bufio"
"context"
"errors"
"io"
"time"

Expand All @@ -17,27 +17,33 @@ import (
pb "github.com/libp2p/go-libp2p-pubsub-router/pb"
)

var GetLatestErr = errors.New("get-latest: received error")

type getLatestProtocol struct {
ctx context.Context
host host.Host
}

func newGetLatestProtocol(host host.Host, getLocal func(key string) ([]byte, error)) *getLatestProtocol {
p := &getLatestProtocol{host}
func newGetLatestProtocol(ctx context.Context, host host.Host, getLocal func(key string) ([]byte, error)) *getLatestProtocol {
p := &getLatestProtocol{ctx, host}

host.SetStreamHandler(PSGetLatestProto, func(s network.Stream) {
p.Receive(s, getLocal)
p.receive(s, getLocal)
})

return p
}

func (p *getLatestProtocol) Receive(s network.Stream, getLocal func(key string) ([]byte, error)) {
r := ggio.NewDelimitedReader(s, 1<<20)
func (p *getLatestProtocol) receive(s network.Stream, getLocal func(key string) ([]byte, error)) {
msg := &pb.RequestLatest{}
if err := r.ReadMsg(msg); err != nil {
if err := readMsg(p.ctx, s, msg); err != nil {
if err != io.EOF {
s.Reset()
log.Infof("error reading request from %s: %s", s.Conn().RemotePeer(), err)
respProto := pb.RespondLatest{Status: pb.RespondLatest_ERR}
if err := writeMsg(p.ctx, s, &respProto); err != nil {
return
}
helpers.FullClose(s)
} else {
// Just be nice. They probably won't read this
// but it doesn't hurt to send it.
Expand All @@ -46,64 +52,98 @@ func (p *getLatestProtocol) Receive(s network.Stream, getLocal func(key string)
return
}

response, err := getLocal(*msg.Identifier)
response, err := getLocal(msg.Identifier)
var respProto pb.RespondLatest

if err != nil || response == nil {
nodata := true
respProto = pb.RespondLatest{Nodata: &nodata}
if err != nil {
respProto = pb.RespondLatest{Status: pb.RespondLatest_NOT_FOUND}
} else {
respProto = pb.RespondLatest{Data: response}
}

if err := writeBytes(s, &respProto); err != nil {
s.Reset()
log.Infof("error writing response to %s: %s", s.Conn().RemotePeer(), err)
if err := writeMsg(p.ctx, s, &respProto); err != nil {
return
}
helpers.FullClose(s)
}

func (p getLatestProtocol) Send(ctx context.Context, pid peer.ID, key string) ([]byte, error) {
func (p getLatestProtocol) Get(ctx context.Context, pid peer.ID, key string) ([]byte, error) {
peerCtx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()

s, err := p.host.NewStream(peerCtx, pid, PSGetLatestProto)
if err != nil {
return nil, err
}

if err := s.SetDeadline(time.Now().Add(time.Second * 5)); err != nil {
return nil, err
}

defer helpers.FullClose(s)

msg := pb.RequestLatest{Identifier: &key}
msg := &pb.RequestLatest{Identifier: key}

if err := writeBytes(s, &msg); err != nil {
s.Reset()
if err := writeMsg(ctx, s, msg); err != nil {
return nil, err
}

s.Close()

r := ggio.NewDelimitedReader(s, 1<<20)
response := &pb.RespondLatest{}
if err := r.ReadMsg(response); err != nil {
if err := readMsg(ctx, s, response); err != nil {
return nil, err
}

return response.Data, nil
switch response.Status {
case pb.RespondLatest_SUCCESS:
return response.Data, nil
case pb.RespondLatest_NOT_FOUND:
return nil, nil
case pb.RespondLatest_ERR:
return nil, GetLatestErr
default:
return nil, errors.New("get-latest: received unknown status code")
}
}

func writeBytes(w io.Writer, msg proto.Message) error {
bufw := bufio.NewWriter(w)
wc := ggio.NewDelimitedWriter(bufw)
func writeMsg(ctx context.Context, s network.Stream, msg proto.Message) error {
done := make(chan error)
go func() {
wc := ggio.NewDelimitedWriter(s)

if err := wc.WriteMsg(msg); err != nil {
return err
if err := wc.WriteMsg(msg); err != nil {
done <- err
return
}

done <- nil
}()

var retErr error
select {
case retErr = <-done:
case <-ctx.Done():
retErr = ctx.Err()
}

if retErr != nil {
s.Reset()
log.Infof("error writing response to %s: %s", s.Conn().RemotePeer(), retErr)
}
return retErr
}

return bufw.Flush()
func readMsg(ctx context.Context, s network.Stream, msg proto.Message) error {
done := make(chan error)
go func() {
r := ggio.NewDelimitedReader(s, 1<<20)
if err := r.ReadMsg(msg); err != nil {
done <- err
return
}
done <- nil
}()

select {
case err := <-done:
return err
case <-ctx.Done():
s.Reset()
return ctx.Err()
}
}
56 changes: 48 additions & 8 deletions getlatest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ package namesys
import (
"bytes"
"context"
"encoding/binary"
"errors"
"testing"
"time"

"github.com/libp2p/go-libp2p-core/helpers"
"github.com/libp2p/go-libp2p-core/host"

pb "github.com/libp2p/go-libp2p-pubsub-router/pb"
)

func connect(t *testing.T, a, b host.Host) {
Expand Down Expand Up @@ -41,16 +45,16 @@ func TestGetLatestProtocolTrip(t *testing.T) {
time.Sleep(time.Millisecond * 100)

d1 := &datastore{map[string][]byte{"key": []byte("value1")}}
h1 := newGetLatestProtocol(hosts[0], d1.Lookup)
h1 := newGetLatestProtocol(ctx, hosts[0], d1.Lookup)

d2 := &datastore{map[string][]byte{"key": []byte("value2")}}
h2 := newGetLatestProtocol(hosts[1], d2.Lookup)
h2 := newGetLatestProtocol(ctx, hosts[1], d2.Lookup)

getLatest(t, ctx, h1, h2, "key", []byte("value2"))
getLatest(t, ctx, h2, h1, "key", []byte("value1"))
}

func TestGetLatestProtocolNil(t *testing.T) {
func TestGetLatestProtocolNotFound(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand All @@ -61,15 +65,51 @@ func TestGetLatestProtocolNil(t *testing.T) {
time.Sleep(time.Millisecond * 100)

d1 := &datastore{map[string][]byte{"key": []byte("value1")}}
h1 := newGetLatestProtocol(hosts[0], d1.Lookup)
h1 := newGetLatestProtocol(ctx, hosts[0], d1.Lookup)

d2 := &datastore{make(map[string][]byte)}
h2 := newGetLatestProtocol(hosts[1], d2.Lookup)
h2 := newGetLatestProtocol(ctx, hosts[1], d2.Lookup)

getLatest(t, ctx, h1, h2, "key", nil)
getLatest(t, ctx, h2, h1, "key", []byte("value1"))
}

func TestGetLatestProtocolErr(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

hosts := newNetHosts(ctx, t, 2)
connect(t, hosts[0], hosts[1])

// wait for hosts to get connected
time.Sleep(time.Millisecond * 100)

d1 := &datastore{make(map[string][]byte)}
h1 := newGetLatestProtocol(ctx, hosts[0], d1.Lookup)

// bad send protocol to force an error
s, err := hosts[1].NewStream(ctx, h1.host.ID(), PSGetLatestProto)
if err != nil {
t.Fatal(err)
}
defer helpers.FullClose(s)

buf := make([]byte, binary.MaxVarintLen64)
binary.PutUvarint(buf, ^uint64(0))
if _, err := s.Write(buf); err != nil {
t.Fatal(err)
}

response := &pb.RespondLatest{}
if err := readMsg(ctx, s, response); err != nil {
t.Fatal(err)
}

if response.Status != pb.RespondLatest_ERR {
t.Fatal("should have received an error")
}
}

func TestGetLatestProtocolRepeated(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -81,10 +121,10 @@ func TestGetLatestProtocolRepeated(t *testing.T) {
time.Sleep(time.Millisecond * 100)

d1 := &datastore{map[string][]byte{"key": []byte("value1")}}
h1 := newGetLatestProtocol(hosts[0], d1.Lookup)
h1 := newGetLatestProtocol(ctx, hosts[0], d1.Lookup)

d2 := &datastore{make(map[string][]byte)}
h2 := newGetLatestProtocol(hosts[1], d2.Lookup)
h2 := newGetLatestProtocol(ctx, hosts[1], d2.Lookup)

for i := 0; i < 10; i++ {
getLatest(t, ctx, h1, h2, "key", nil)
Expand All @@ -94,7 +134,7 @@ func TestGetLatestProtocolRepeated(t *testing.T) {

func getLatest(t *testing.T, ctx context.Context,
requester *getLatestProtocol, responder *getLatestProtocol, key string, expected []byte) {
data, err := requester.Send(ctx, responder.host.ID(), key)
data, err := requester.Get(ctx, responder.host.ID(), key)
if err != nil {
t.Fatal(err)
}
Expand Down
8 changes: 7 additions & 1 deletion pb/Makefile
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
PB = $(wildcard *.proto)
GO = $(PB:.proto=.pb.go)

ifeq ($(OS),Windows_NT)
GOPATH_DELIMITER = \;

This comment has been minimized.

Copy link
@djdv

djdv Aug 9, 2019

I don't believe the \ has any effect in this scenario (on Windows)

else
GOPATH_DELIMITER = :
endif

all: $(GO)

%.pb.go: %.proto
protoc --proto_path=$(GOPATH)/src:. --gogofast_out=. $<
protoc --proto_path=$(GOPATH)/src$(GOPATH_DELIMITER). --gogofast_out=. $<

This comment has been minimized.

Copy link
@djdv

djdv Aug 9, 2019

$(GOPATH)/src$(GOPATH_DELIMITER). -> "$(GOPATH)/src$(GOPATH_DELIMITER)."


clean:
rm -f *.pb.go
Expand Down
Loading

0 comments on commit f85f2bc

Please sign in to comment.