Skip to content

Commit

Permalink
make tests pass on base package.
Browse files Browse the repository at this point in the history
  • Loading branch information
fiatjaf committed May 1, 2023
1 parent a4512da commit e84f5df
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 74 deletions.
10 changes: 8 additions & 2 deletions start.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ type Server struct {
// outputting to stderr.
Log Logger

addr string
relay Relay

// keep a connection reference to all connected clients for Server.Shutdown
clientsMu sync.Mutex
clients map[*websocket.Conn]struct{}

// in case you call Server.Start
Addr string
httpServer *http.Server
}

Expand Down Expand Up @@ -83,13 +83,14 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

func (s *Server) Start(host string, port int) error {
func (s *Server) Start(host string, port int, started ...chan bool) error {
addr := net.JoinHostPort(host, strconv.Itoa(port))
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}

s.Addr = ln.Addr().String()
s.httpServer = &http.Server{
Handler: cors.Default().Handler(s),
Addr: addr,
Expand All @@ -98,6 +99,11 @@ func (s *Server) Start(host string, port int) error {
IdleTimeout: 30 * time.Second,
}

// notify caller that we're starting
for _, started := range started {
close(started)
}

if err := s.httpServer.Serve(ln); err == http.ErrServerClosed {
return nil
} else if err != nil {
Expand Down
43 changes: 13 additions & 30 deletions start_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,28 @@ import (

func TestServerStartShutdown(t *testing.T) {
var (
serverHost string
inited bool
storeInited bool
shutdown bool
)
ready := make(chan struct{})
rl := &testRelay{
name: "test server start",
init: func() error {
inited = true
return nil
},
onInitialized: func(s *Server) {
serverHost = s.Addr()
close(ready)
},
onShutdown: func(context.Context) { shutdown = true },
storage: &testStorage{
init: func() error { storeInited = true; return nil },
},
}
srv := NewServer("127.0.0.1:0", rl)
srv, _ := NewServer(rl)
ready := make(chan bool)
done := make(chan error)
go func() { done <- srv.Start(); close(done) }()
go func() { done <- srv.Start("127.0.0.1", 0, ready); close(done) }()
<-ready

// verify everything's initialized
select {
case <-ready:
// continue
case <-time.After(time.Second):
t.Fatal("srv.Start too long to initialize")
}
if !inited {
t.Error("didn't call testRelay.init")
}
Expand All @@ -52,16 +42,14 @@ func TestServerStartShutdown(t *testing.T) {
}

// check that http requests are served
if _, err := http.Get("http://" + serverHost); err != nil {
t.Errorf("GET %s: %v", serverHost, err)
if _, err := http.Get("http://" + srv.Addr); err != nil {
t.Errorf("GET %s: %v", srv.Addr, err)
}

// verify server shuts down
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Errorf("srv.Shutdown: %v", err)
}
srv.Shutdown(ctx)
if !shutdown {
t.Error("didn't call testRelay.onShutdown")
}
Expand All @@ -82,25 +70,20 @@ func TestServerShutdownWebsocket(t *testing.T) {
// connect a client to it
ctx1, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
client, err := nostr.RelayConnect(ctx1, "ws://"+srv.Addr())
client, err := nostr.RelayConnect(ctx1, "ws://"+srv.Addr)
if err != nil {
t.Fatalf("nostr.RelayConnectContext: %v", err)
}

// now, shut down the server
ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := srv.Shutdown(ctx2); err != nil {
t.Errorf("srv.Shutdown: %v", err)
}
srv.Shutdown(ctx2)

// wait for the client to receive a "connection close"
select {
case err := <-client.ConnectionError:
if _, ok := err.(*websocket.CloseError); !ok {
t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err)
}
case <-time.After(2 * time.Second):
t.Error("client took too long to disconnect")
time.Sleep(1 * time.Second)
err = client.ConnectionError
if _, ok := err.(*websocket.CloseError); !ok {
t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err)
}
}
63 changes: 21 additions & 42 deletions util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,29 @@ package relayer
import (
"context"
"testing"
"time"

"github.com/nbd-wtf/go-nostr"
)

func startTestRelay(t *testing.T, tr *testRelay) *Server {
t.Helper()
ready := make(chan struct{})

onInitializedFn := tr.onInitialized
tr.onInitialized = func(s *Server) {
close(ready)
if onInitializedFn != nil {
onInitializedFn(s)
}
}
srv := NewServer("127.0.0.1:0", tr)
go srv.Start()

select {
case <-ready:
case <-time.After(time.Second):
t.Fatal("server took too long to start up")
}
srv, _ := NewServer(tr)
started := make(chan bool)
go srv.Start("127.0.0.1", 0, started)
<-started
return srv
}

type testRelay struct {
name string
storage Storage
init func() error
onInitialized func(*Server)
onShutdown func(context.Context)
acceptEvent func(*nostr.Event) bool
name string
storage Storage
init func() error
onShutdown func(context.Context)
acceptEvent func(*nostr.Event) bool
}

func (tr *testRelay) Name() string { return tr.name }
func (tr *testRelay) Storage() Storage { return tr.storage }
func (tr *testRelay) Name() string { return tr.name }
func (tr *testRelay) Storage(context.Context) Storage { return tr.storage }

func (tr *testRelay) Init() error {
if fn := tr.init; fn != nil {
Expand All @@ -49,19 +34,13 @@ func (tr *testRelay) Init() error {
return nil
}

func (tr *testRelay) OnInitialized(s *Server) {
if fn := tr.onInitialized; fn != nil {
fn(s)
}
}

func (tr *testRelay) OnShutdown(ctx context.Context) {
if fn := tr.onShutdown; fn != nil {
fn(ctx)
}
}

func (tr *testRelay) AcceptEvent(e *nostr.Event) bool {
func (tr *testRelay) AcceptEvent(ctx context.Context, e *nostr.Event) bool {
if fn := tr.acceptEvent; fn != nil {
return fn(e)
}
Expand All @@ -70,9 +49,9 @@ func (tr *testRelay) AcceptEvent(e *nostr.Event) bool {

type testStorage struct {
init func() error
queryEvents func(*nostr.Filter) ([]nostr.Event, error)
deleteEvent func(id string, pubkey string) error
saveEvent func(*nostr.Event) error
queryEvents func(context.Context, *nostr.Filter) (chan *nostr.Event, error)
deleteEvent func(ctx context.Context, id string, pubkey string) error
saveEvent func(context.Context, *nostr.Event) error
}

func (st *testStorage) Init() error {
Expand All @@ -82,23 +61,23 @@ func (st *testStorage) Init() error {
return nil
}

func (st *testStorage) QueryEvents(f *nostr.Filter) ([]nostr.Event, error) {
func (st *testStorage) QueryEvents(ctx context.Context, f *nostr.Filter) (chan *nostr.Event, error) {
if fn := st.queryEvents; fn != nil {
return fn(f)
return fn(ctx, f)
}
return nil, nil
}

func (st *testStorage) DeleteEvent(id string, pubkey string) error {
func (st *testStorage) DeleteEvent(ctx context.Context, id string, pubkey string) error {
if fn := st.deleteEvent; fn != nil {
return fn(id, pubkey)
return fn(ctx, id, pubkey)
}
return nil
}

func (st *testStorage) SaveEvent(e *nostr.Event) error {
func (st *testStorage) SaveEvent(ctx context.Context, e *nostr.Event) error {
if fn := st.saveEvent; fn != nil {
return fn(e)
return fn(ctx, e)
}
return nil
}

0 comments on commit e84f5df

Please sign in to comment.