Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 6 additions & 62 deletions assert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,19 @@ package websocket_test

import (
"context"
"fmt"
"math/rand"
"reflect"
"strings"
"time"

"github.com/google/go-cmp/cmp"

"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/assert"
"nhooyr.io/websocket/wsjson"
)

func init() {
rand.Seed(time.Now().UnixNano())
}

// https://github.com/google/go-cmp/issues/40#issuecomment-328615283
func cmpDiff(exp, act interface{}) string {
return cmp.Diff(exp, act, deepAllowUnexported(exp, act))
}

func deepAllowUnexported(vs ...interface{}) cmp.Option {
m := make(map[reflect.Type]struct{})
for _, v := range vs {
structTypes(reflect.ValueOf(v), m)
}
var typs []interface{}
for t := range m {
typs = append(typs, reflect.New(t).Elem().Interface())
}
return cmp.AllowUnexported(typs...)
}

func structTypes(v reflect.Value, m map[reflect.Type]struct{}) {
if !v.IsValid() {
return
}
switch v.Kind() {
case reflect.Ptr:
if !v.IsNil() {
structTypes(v.Elem(), m)
}
case reflect.Interface:
if !v.IsNil() {
structTypes(v.Elem(), m)
}
case reflect.Slice, reflect.Array:
for i := 0; i < v.Len(); i++ {
structTypes(v.Index(i), m)
}
case reflect.Map:
for _, k := range v.MapKeys() {
structTypes(v.MapIndex(k), m)
}
case reflect.Struct:
m[v.Type()] = struct{}{}
for i := 0; i < v.NumField(); i++ {
structTypes(v.Field(i), m)
}
}
}

func assertEqualf(exp, act interface{}, f string, v ...interface{}) error {
if diff := cmpDiff(exp, act); diff != "" {
return fmt.Errorf(f+": %v", append(v, diff)...)
}
return nil
}

func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error {
exp := randString(n)
err := wsjson.Write(ctx, c, exp)
Expand All @@ -84,7 +28,7 @@ func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error {
return err
}

return assertEqualf(exp, act, "unexpected JSON")
return assert.Equalf(exp, act, "unexpected JSON")
}

func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) error {
Expand All @@ -94,7 +38,7 @@ func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) err
return err
}

return assertEqualf(exp, act, "unexpected JSON")
return assert.Equalf(exp, act, "unexpected JSON")
}

func randBytes(n int) []byte {
Expand Down Expand Up @@ -126,13 +70,13 @@ func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageTyp
if err != nil {
return err
}
err = assertEqualf(typ, typ2, "unexpected data type")
err = assert.Equalf(typ, typ2, "unexpected data type")
if err != nil {
return err
}
return assertEqualf(p, p2, "unexpected payload")
return assert.Equalf(p, p2, "unexpected payload")
}

func assertSubprotocol(c *websocket.Conn, exp string) error {
return assertEqualf(exp, c.Subprotocol(), "unexpected subprotocol")
return assert.Equalf(exp, c.Subprotocol(), "unexpected subprotocol")
}
39 changes: 20 additions & 19 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"go.uber.org/multierr"

"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/assert"
"nhooyr.io/websocket/internal/wsecho"
"nhooyr.io/websocket/wsjson"
"nhooyr.io/websocket/wspb"
Expand Down Expand Up @@ -127,7 +128,7 @@ func TestHandshake(t *testing.T) {
if err != nil {
return fmt.Errorf("request is missing mycookie: %w", err)
}
err = assertEqualf("myvalue", cookie.Value, "unexpected cookie value")
err = assert.Equalf("myvalue", cookie.Value, "unexpected cookie value")
if err != nil {
return err
}
Expand Down Expand Up @@ -219,7 +220,7 @@ func TestConn(t *testing.T) {
}
for h, exp := range headers {
value := resp.Header.Get(h)
err := assertEqualf(exp, value, "unexpected value for header %v", h)
err := assert.Equalf(exp, value, "unexpected value for header %v", h)
if err != nil {
return err
}
Expand Down Expand Up @@ -276,11 +277,11 @@ func TestConn(t *testing.T) {
time.Sleep(1)
nc.SetWriteDeadline(time.Now().Add(time.Second * 15))

err := assertEqualf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr")
err := assert.Equalf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr")
if err != nil {
return err
}
err = assertEqualf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr")
err = assert.Equalf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr")
if err != nil {
return err
}
Expand Down Expand Up @@ -310,13 +311,13 @@ func TestConn(t *testing.T) {

// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
err2 := assertNetConnRead(nc, "hello")
err := assertEqualf(io.EOF, err2, "unexpected error")
err := assert.Equalf(io.EOF, err2, "unexpected error")
if err != nil {
return err
}

err2 = assertNetConnRead(nc, "hello")
return assertEqualf(io.EOF, err2, "unexpected error")
return assert.Equalf(io.EOF, err2, "unexpected error")
},
},
{
Expand Down Expand Up @@ -772,15 +773,15 @@ func TestConn(t *testing.T) {
if err != nil {
return err
}
err = assertEqualf("hi", v, "unexpected JSON")
err = assert.Equalf("hi", v, "unexpected JSON")
if err != nil {
return err
}
_, b, err := c.Read(ctx)
if err != nil {
return err
}
return assertEqualf("hi", string(b), "unexpected JSON")
return assert.Equalf("hi", string(b), "unexpected JSON")
},
client: func(ctx context.Context, c *websocket.Conn) error {
err := wsjson.Write(ctx, c, "hi")
Expand Down Expand Up @@ -1079,11 +1080,11 @@ func TestAutobahn(t *testing.T) {
if err != nil {
return err
}
err = assertEqualf(typ, actTyp, "unexpected message type")
err = assert.Equalf(typ, actTyp, "unexpected message type")
if err != nil {
return err
}
return assertEqualf(p, p2, "unexpected message")
return assert.Equalf(p, p2, "unexpected message")
})
}
}
Expand Down Expand Up @@ -1859,7 +1860,7 @@ func assertCloseStatus(err error, code websocket.StatusCode) error {
if !errors.As(err, &cerr) {
return fmt.Errorf("no websocket close error in error chain: %+v", err)
}
return assertEqualf(code, cerr.Code, "unexpected status code")
return assert.Equalf(code, cerr.Code, "unexpected status code")
}

func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error {
Expand All @@ -1871,7 +1872,7 @@ func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{})
return err
}

return assertEqualf(exp, act, "unexpected protobuf")
return assert.Equalf(exp, act, "unexpected protobuf")
}

func assertNetConnRead(r io.Reader, exp string) error {
Expand All @@ -1880,7 +1881,7 @@ func assertNetConnRead(r io.Reader, exp string) error {
if err != nil {
return err
}
return assertEqualf(exp, string(act), "unexpected net conn read")
return assert.Equalf(exp, string(act), "unexpected net conn read")
}

func assertErrorContains(err error, exp string) error {
Expand All @@ -1902,27 +1903,27 @@ func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.Op
if err != nil {
return err
}
err = assertEqualf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP)
err = assert.Equalf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP)
if err != nil {
return err
}
return assertEqualf(p, actP, "unexpected frame %v payload", opcode)
return assert.Equalf(p, actP, "unexpected frame %v payload", opcode)
}

func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error {
actOpcode, actP, err := c.ReadFrame(ctx)
if err != nil {
return err
}
err = assertEqualf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP)
err = assert.Equalf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP)
if err != nil {
return err
}
ce, err := websocket.ParseClosePayload(actP)
if err != nil {
return fmt.Errorf("failed to parse close frame payload: %w", err)
}
return assertEqualf(ce.Code, code, "unexpected frame close frame code with payload %q", actP)
return assert.Equalf(ce.Code, code, "unexpected frame close frame code with payload %q", actP)
}

func assertCloseHandshake(ctx context.Context, c *websocket.Conn, code websocket.StatusCode, reason string) error {
Expand Down Expand Up @@ -1960,11 +1961,11 @@ func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.Mes
if err != nil {
return err
}
err = assertEqualf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP)
err = assert.Equalf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP)
if err != nil {
return err
}
return assertEqualf(p, actP, "unexpected frame %v payload", actTyp)
return assert.Equalf(p, actP, "unexpected frame %v payload", actTyp)
}

func BenchmarkConn(b *testing.B) {
Expand Down
12 changes: 12 additions & 0 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package websocket

import (
"encoding/binary"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -252,6 +253,17 @@ func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}

// CloseStatus is a convenience wrapper around xerrors.As to grab
// the status code from a *CloseError. If the passed error is nil
// or not a *CloseError, the returned StatusCode will be -1.
func CloseStatus(err error) StatusCode {
var ce *CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}

Copy link
Contributor Author

@nhooyr nhooyr Oct 5, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main addition in this PR, the rest is just some refactoring to allow the use of assert.Equalf in the tests for this function.

func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Expand Down
42 changes: 42 additions & 0 deletions frame_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"time"

"github.com/google/go-cmp/cmp"

"nhooyr.io/websocket/internal/assert"
)

func init() {
Expand Down Expand Up @@ -376,3 +378,43 @@ func BenchmarkXOR(b *testing.B) {
})
}
}

func TestCloseStatus(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
in error
exp StatusCode
}{
{
name: "nil",
in: nil,
exp: -1,
},
{
name: "io.EOF",
in: io.EOF,
exp: -1,
},
{
name: "StatusInternalError",
in: &CloseError{
Code: StatusInternalError,
},
exp: StatusInternalError,
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

err := assert.Equalf(tc.exp, CloseStatus(tc.in), "unexpected close status")
if err != nil {
t.Fatal(err)
}
})
}
}
Loading