Skip to content

Commit

Permalink
Use generic read function
Browse files Browse the repository at this point in the history
  • Loading branch information
martonp committed Oct 7, 2024
1 parent 22e99a5 commit bd18aaa
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 65 deletions.
148 changes: 84 additions & 64 deletions client/comms/wsconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ type WsCfg struct {

// AutoReconnect, if non-zero, will reconnect to the server after each
// interval of the amount of time specified.
AutoReconnect time.Duration
AutoReconnect *time.Duration

// The server's certificate.
Cert []byte
Expand Down Expand Up @@ -392,63 +392,89 @@ func (conn *wsConn) close() {
conn.ws.Close()
}

func (conn *wsConn) readRaw(ctx context.Context) {
func genericRead[T any](ctx context.Context,
read func() (T, error),
handleResp func(T),
handleErr func(error) bool, // return true if error should stop the loop
reconnect func(),
autoReconnect *time.Duration) {

var reconnectTimer <-chan time.Time
if autoReconnect != nil {
reconnectTimer = time.After(*autoReconnect)
}

type readResult struct {
msg T
err error
}

readMessage := func() chan *readResult {
ch := make(chan *readResult, 1)
go func() {
T, err := read()
ch <- &readResult{T, err}
}()
return ch
}

for {
// Lock since conn.ws may be set by connect.
conn.wsMtx.Lock()
ws := conn.ws
conn.wsMtx.Unlock()
select {
case result := <-readMessage():
if ctx.Err() != nil {
return
}
if result.err != nil {
if handleErr(result.err) {
return
}
continue
}

// Block until a message is received or an error occurs.
_, msgBytes, err := ws.ReadMessage()
// Drop the read error on context cancellation.
if ctx.Err() != nil {
handleResp(result.msg)
case <-reconnectTimer:
reconnect()
return
}
if err != nil {
conn.handleReadError(err)
case <-ctx.Done():
return
}
conn.cfg.RawHandler(msgBytes)
}
}

func (conn *wsConn) readRaw(ctx context.Context) {
read := func() ([]byte, error) {
_, b, err := conn.ws.ReadMessage()
return b, err
}

handleError := func(err error) bool {
conn.handleReadError(err)
return true
}

reconnect := func() {
conn.reconnectCh <- struct{}{}
}

genericRead(ctx, read, conn.cfg.RawHandler, handleError, reconnect, conn.cfg.AutoReconnect)
}

// read fetches and parses incoming messages for processing. This should be
// run as a goroutine. Increment the wg before calling read.
func (conn *wsConn) read(ctx context.Context) {
for {
read := func() (*msgjson.Message, error) {
msg := new(msgjson.Message)
err := conn.ws.ReadJSON(msg)
return msg, err
}

// Lock since conn.ws may be set by connect.
conn.wsMtx.Lock()
ws := conn.ws
conn.wsMtx.Unlock()

// The read itself does not require locking since only this goroutine
// uses read functions that are not safe for concurrent use.
err := ws.ReadJSON(msg)
// Drop the read error on context cancellation.
if ctx.Err() != nil {
return
}
if err != nil {
var mErr *json.UnmarshalTypeError
if errors.As(err, &mErr) {
// JSON decode errors are not fatal, log and proceed.
conn.log.Errorf("json decode error: %v", mErr)
continue
}
conn.handleReadError(err)
return
}

// If the message is a response, find the handler.
handleResp := func(msg *msgjson.Message) {
if msg.Type == msgjson.Response {
handler := conn.respHandler(msg.ID)
if handler == nil {
b, _ := json.Marshal(msg)
conn.log.Errorf("No handler found for response: %v", string(b))
continue
return
}
// Run handlers in a goroutine so that other messages can be
// received. Include the handler goroutines in the WaitGroup to
Expand All @@ -458,10 +484,27 @@ func (conn *wsConn) read(ctx context.Context) {
defer conn.wg.Done()
handler.f(msg)
}()
continue
return
}
conn.readCh <- msg
}

handleErr := func(err error) bool {
var mErr *json.UnmarshalTypeError
if errors.As(err, &mErr) {
// JSON decode errors are not fatal, log and proceed.
conn.log.Errorf("json decode error: %v", mErr)
return false
}
conn.handleReadError(err)
return true
}

reconnect := func() {
conn.reconnectCh <- struct{}{}
}

genericRead(ctx, read, handleResp, handleErr, reconnect, conn.cfg.AutoReconnect)
}

// keepAlive maintains an active websocket connection by reconnecting when
Expand Down Expand Up @@ -571,29 +614,6 @@ func (conn *wsConn) Connect(ctx context.Context) (*sync.WaitGroup, error) {
close(conn.readCh) // signal to MessageSource receivers that the wsConn is dead
}()

if interval := conn.cfg.AutoReconnect; interval > 0 {
conn.wg.Add(1)
go func() {
defer conn.wg.Done()
tick := time.After(interval)
for {
select {
case <-tick:
case <-ctx.Done():
return
}
lastConnect := time.Unix(conn.connected.Load(), 0)
if since := time.Since(lastConnect); since >= interval {
conn.reconnectCh <- struct{}{}
tick = time.After(interval)
} else {
tick = time.After(interval - since)
}
}

}()
}

return &conn.wg, err
}

Expand Down
2 changes: 1 addition & 1 deletion client/mm/libxc/binance.go
Original file line number Diff line number Diff line change
Expand Up @@ -1809,7 +1809,7 @@ func (bnc *binance) connectToMarketDataStream(ctx context.Context, baseID, quote
ConnectEventFunc: connectEventFunc,
Logger: bnc.log.SubLogger("BNCBOOK"),
RawHandler: bnc.handleMarketDataNote,
AutoReconnect: reconnectInterval,
AutoReconnect: &reconnectInterval,
})
if err != nil {
return err
Expand Down

0 comments on commit bd18aaa

Please sign in to comment.