Skip to content

Commit

Permalink
pkg/exchange: merge FeeRatePoller into StreamDataProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
bailantaotao committed Sep 30, 2024
1 parent eb7fcfc commit 246aa6d
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 24 deletions.
8 changes: 4 additions & 4 deletions pkg/exchange/bybit/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ type Exchange struct {
// Because the bybit exchange does not provide a fee currency on traditional SPOT accounts, we need to query the marker
// fee rate to get the fee currency.
// https://bybit-exchange.github.io/docs/v5/enum#spot-fee-currency-instruction
feeRateProvider FeeRatePoller
FeeRatePoller
}

func New(key, secret string) (*Exchange, error) {
Expand All @@ -74,7 +74,7 @@ func New(key, secret string) (*Exchange, error) {
}
if len(key) > 0 && len(secret) > 0 {
client.Auth(key, secret)
ex.feeRateProvider = newFeeRatePoller(ex)
ex.FeeRatePoller = newFeeRatePoller(ex)

ctx, cancel := context.WithTimeoutCause(context.Background(), 5*time.Second, errors.New("query markets timeout"))
defer cancel()
Expand Down Expand Up @@ -437,7 +437,7 @@ func (e *Exchange) queryTrades(ctx context.Context, req *bybitapi.GetExecutionLi
}

for _, trade := range res.List {
feeRate, err := pollAndGetFeeRate(ctx, trade.Symbol, e.feeRateProvider, e.marketsInfo)
feeRate, err := pollAndGetFeeRate(ctx, trade.Symbol, e.FeeRatePoller, e.marketsInfo)
if err != nil {
return nil, fmt.Errorf("failed to get fee rate, err: %v", err)
}
Expand Down Expand Up @@ -607,5 +607,5 @@ func (e *Exchange) GetAllFeeRates(ctx context.Context) (bybitapi.FeeRates, error
}

func (e *Exchange) NewStream() types.Stream {
return NewStream(e.key, e.secret, e, e.feeRateProvider)
return NewStream(e.key, e.secret, e)
}
16 changes: 8 additions & 8 deletions pkg/exchange/bybit/market_info_poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ var (
)

type FeeRatePoller interface {
Start(ctx context.Context)
Get(symbol string) (SymbolFeeDetail, bool)
Poll(ctx context.Context) error
StartFeeRatePoller(ctx context.Context)
GetFeeRate(symbol string) (SymbolFeeDetail, bool)
PollFeeRate(ctx context.Context) error
}

type SymbolFeeDetail struct {
Expand Down Expand Up @@ -53,14 +53,14 @@ func newFeeRatePoller(marketInfoProvider MarketInfoProvider) *feeRatePoller {
}
}

func (p *feeRatePoller) Start(ctx context.Context) {
func (p *feeRatePoller) StartFeeRatePoller(ctx context.Context) {
p.once.Do(func() {
p.startLoop(ctx)
})
}

func (p *feeRatePoller) startLoop(ctx context.Context) {
err := p.Poll(ctx)
err := p.PollFeeRate(ctx)
if err != nil {
log.WithError(err).Warn("failed to initialize the fee rate, the ticker is scheduled to update it subsequently")
}
Expand All @@ -76,14 +76,14 @@ func (p *feeRatePoller) startLoop(ctx context.Context) {

return
case <-ticker.C:
if err := p.Poll(ctx); err != nil {
if err := p.PollFeeRate(ctx); err != nil {
log.WithError(err).Warn("failed to update fee rate")
}
}
}
}

func (p *feeRatePoller) Poll(ctx context.Context) error {
func (p *feeRatePoller) PollFeeRate(ctx context.Context) error {
p.mu.Lock()
defer p.mu.Unlock()
// the poll will be called frequently, so we need to check the last sync time.
Expand All @@ -105,7 +105,7 @@ func (p *feeRatePoller) Poll(ctx context.Context) error {
return nil
}

func (p *feeRatePoller) Get(symbol string) (SymbolFeeDetail, bool) {
func (p *feeRatePoller) GetFeeRate(symbol string) (SymbolFeeDetail, bool) {
p.mu.Lock()
defer p.mu.Unlock()

Expand Down
4 changes: 2 additions & 2 deletions pkg/exchange/bybit/market_info_poller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func Test_feeRatePoller_Get(t *testing.T) {
},
}

res, found := s.Get(symbol)
res, found := s.GetFeeRate(symbol)
assert.True(t, found)
assert.Equal(t, expFeeDetail, res)
})
Expand All @@ -165,7 +165,7 @@ func Test_feeRatePoller_Get(t *testing.T) {
symbolFeeDetail: map[string]SymbolFeeDetail{},
}

_, found := s.Get(symbol)
_, found := s.GetFeeRate(symbol)
assert.False(t, found)
})
}
10 changes: 5 additions & 5 deletions pkg/exchange/bybit/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type AccountBalanceProvider interface {
type StreamDataProvider interface {
MarketInfoProvider
AccountBalanceProvider
FeeRatePoller
}

//go:generate callbackgen -type Stream
Expand All @@ -70,14 +71,13 @@ type Stream struct {
tradeEventCallbacks []func(e []TradeEvent)
}

func NewStream(key, secret string, userDataProvider StreamDataProvider, poller FeeRatePoller) *Stream {
func NewStream(key, secret string, userDataProvider StreamDataProvider) *Stream {
stream := &Stream{
StandardStream: types.NewStandardStream(),
// pragma: allowlist nextline secret
key: key,
secret: secret,
streamDataProvider: userDataProvider,
feeRateProvider: poller,
}

stream.SetEndpointCreator(stream.createEndpoint)
Expand All @@ -91,7 +91,7 @@ func NewStream(key, secret string, userDataProvider StreamDataProvider, poller F
}

// get account fee rate
go stream.feeRateProvider.Start(ctx)
go stream.streamDataProvider.StartFeeRatePoller(ctx)

stream.marketsInfo, err = stream.streamDataProvider.QueryMarkets(ctx)
if err != nil {
Expand Down Expand Up @@ -440,15 +440,15 @@ func (s *Stream) handleKLineEvent(klineEvent KLineEvent) {
}

func pollAndGetFeeRate(ctx context.Context, symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) (SymbolFeeDetail, error) {
err := poller.Poll(ctx)
err := poller.PollFeeRate(ctx)
if err != nil {
return SymbolFeeDetail{}, err
}
return getFeeRate(symbol, poller, marketsInfo), nil
}

func getFeeRate(symbol string, poller FeeRatePoller, marketsInfo types.MarketMap) SymbolFeeDetail {
feeRate, found := poller.Get(symbol)
feeRate, found := poller.GetFeeRate(symbol)
if !found {
feeRate = SymbolFeeDetail{
FeeRate: bybitapi.FeeRate{
Expand Down
2 changes: 1 addition & 1 deletion pkg/exchange/bybit/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func getTestClientOrSkip(t *testing.T) *Stream {

exchange, err := New(key, secret)
assert.NoError(t, err)
return NewStream(key, secret, exchange, newFeeRatePoller(exchange))
return NewStream(key, secret, exchange)
}

func TestStream(t *testing.T) {
Expand Down
8 changes: 4 additions & 4 deletions pkg/exchange/bybit/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

func Test_parseWebSocketEvent(t *testing.T) {
t.Run("[public] PingEvent without req id", func(t *testing.T) {
s := NewStream("", "", nil, nil)
s := NewStream("", "", nil)
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","op":"ping"}`
raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err)
Expand All @@ -26,7 +26,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
})

t.Run("[public] PingEvent with req id", func(t *testing.T) {
s := NewStream("", "", nil, nil)
s := NewStream("", "", nil)
msg := `{"success":true,"ret_msg":"pong","conn_id":"a806f6c4-3608-4b6d-a225-9f5da975bc44","req_id":"b26704da-f5af-44c2-bdf7-935d6739e1a0","op":"ping"}`
raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err)
Expand All @@ -37,7 +37,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
})

t.Run("[private] PingEvent without req id", func(t *testing.T) {
s := NewStream("", "", nil, nil)
s := NewStream("", "", nil)
msg := `{"op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err)
Expand All @@ -48,7 +48,7 @@ func Test_parseWebSocketEvent(t *testing.T) {
})

t.Run("[private] PingEvent with req id", func(t *testing.T) {
s := NewStream("", "", nil, nil)
s := NewStream("", "", nil)
msg := `{"req_id":"78d36b57-a142-47b7-9143-5843df77d44d","op":"pong","args":["1690884539181"],"conn_id":"civn4p1dcjmtvb69ome0-yrt1"}`
raw, err := s.parseWebSocketEvent([]byte(msg))
assert.NoError(t, err)
Expand Down

0 comments on commit 246aa6d

Please sign in to comment.