diff --git a/pkg/indicator/boll.go b/pkg/indicator/boll.go index ebc810fe52..3fa49030cb 100644 --- a/pkg/indicator/boll.go +++ b/pkg/indicator/boll.go @@ -3,9 +3,6 @@ package indicator import ( "time" - log "github.com/sirupsen/logrus" - "gonum.org/v1/gonum/stat" - "github.com/c9s/bbgo/pkg/types" ) @@ -24,13 +21,15 @@ Bollinger Bands Technical indicator guide: //go:generate callbackgen -type BOLL type BOLL struct { + types.SeriesBase types.IntervalWindow - // times of Std, generally it's 2 + // K is the multiplier of Std, generally it's 2 K float64 - SMA types.Float64Slice - StdDev types.Float64Slice + SMA *SMA + StdDev *StdDev + UpBand types.Float64Slice DownBand types.Float64Slice @@ -50,11 +49,11 @@ func (inc *BOLL) GetDownBand() types.SeriesExtend { } func (inc *BOLL) GetSMA() types.SeriesExtend { - return types.NewSeries(&inc.SMA) + return types.NewSeries(inc.SMA) } func (inc *BOLL) GetStdDev() types.SeriesExtend { - return types.NewSeries(&inc.StdDev) + return inc.StdDev } func (inc *BOLL) LastUpBand() float64 { @@ -73,64 +72,49 @@ func (inc *BOLL) LastDownBand() float64 { return inc.DownBand[len(inc.DownBand)-1] } -func (inc *BOLL) LastStdDev() float64 { - if len(inc.StdDev) == 0 { - return 0.0 - } - - return inc.StdDev[len(inc.StdDev)-1] -} - -func (inc *BOLL) LastSMA() float64 { - if len(inc.SMA) > 0 { - return inc.SMA[len(inc.SMA)-1] - } - return 0.0 -} - -func (inc *BOLL) CalculateAndUpdate(kLines []types.KLine) { - if len(kLines) < inc.Window { - return - } - - var index = len(kLines) - 1 - var kline = kLines[index] - - if inc.EndTime != zeroTime && kline.EndTime.Before(inc.EndTime) { - return +func (inc *BOLL) Update(value float64) { + if inc.SMA == nil { + inc.SeriesBase.Series = inc + inc.SMA = &SMA{IntervalWindow: inc.IntervalWindow} } - var recentK = kLines[index-(inc.Window-1) : index+1] - sma, err := calculateSMA(recentK, inc.Window, KLineClosePriceMapper) - if err != nil { - log.WithError(err).Error("SMA error") - return + if inc.StdDev == nil { + inc.StdDev = &StdDev{IntervalWindow: inc.IntervalWindow} } - inc.SMA.Push(sma) + inc.SMA.Update(value) + inc.StdDev.Update(value) - var prices []float64 - for _, k := range recentK { - prices = append(prices, k.Close.Float64()) - } - - var std = stat.StdDev(prices, nil) - inc.StdDev.Push(std) - - var band = inc.K * std + var sma = inc.SMA.Last() + var stdDev = inc.StdDev.Last() + var band = inc.K * stdDev var upBand = sma + band - inc.UpBand.Push(upBand) - var downBand = sma - band + + inc.UpBand.Push(upBand) inc.DownBand.Push(downBand) +} - // update end time - inc.EndTime = kLines[index].EndTime.Time() +func (inc *BOLL) PushK(k types.KLine) { + inc.Update(k.Close.Float64()) +} - // log.Infof("update boll: sma=%f, up=%f, down=%f", sma, upBand, downBand) +func (inc *BOLL) CalculateAndUpdate(allKLines []types.KLine) { + var last = allKLines[len(allKLines)-1] + + if inc.SMA == nil { + for _, k := range allKLines { + if inc.EndTime != zeroTime && k.EndTime.Before(inc.EndTime) { + continue + } + inc.PushK(k) + } + } else { + inc.PushK(last) + } - inc.EmitUpdate(sma, upBand, downBand) + inc.EmitUpdate(inc.SMA.Last(), inc.UpBand.Last(), inc.DownBand.Last()) } func (inc *BOLL) handleKLineWindowUpdate(interval types.Interval, window types.KLineWindow) { diff --git a/pkg/indicator/boll_test.go b/pkg/indicator/boll_test.go index 2d51e08b14..7bc8c574c2 100644 --- a/pkg/indicator/boll_test.go +++ b/pkg/indicator/boll_test.go @@ -4,9 +4,10 @@ import ( "encoding/json" "testing" + "github.com/stretchr/testify/assert" + "github.com/c9s/bbgo/pkg/fixedpoint" "github.com/c9s/bbgo/pkg/types" - "github.com/stretchr/testify/assert" ) /* @@ -60,8 +61,8 @@ func TestBOLL(t *testing.T) { t.Run(tt.name, func(t *testing.T) { boll := BOLL{IntervalWindow: types.IntervalWindow{Window: tt.window}, K: tt.k} boll.CalculateAndUpdate(tt.kLines) - assert.InDelta(t, tt.up, boll.LastUpBand(), Delta) - assert.InDelta(t, tt.down, boll.LastDownBand(), Delta) + assert.InDelta(t, tt.up, boll.UpBand.Last(), Delta) + assert.InDelta(t, tt.down, boll.DownBand.Last(), Delta) }) } diff --git a/pkg/indicator/stddev.go b/pkg/indicator/stddev.go new file mode 100644 index 0000000000..3c6a5bdca3 --- /dev/null +++ b/pkg/indicator/stddev.go @@ -0,0 +1,85 @@ +package indicator + +import ( + "time" + + "github.com/c9s/bbgo/pkg/types" +) + +//go:generate callbackgen -type StdDev +type StdDev struct { + types.SeriesBase + types.IntervalWindow + Values types.Float64Slice + rawValues *types.Queue + + EndTime time.Time + updateCallbacks []func(value float64) +} + +func (inc *StdDev) Last() float64 { + if inc.Values.Length() == 0 { + return 0.0 + } + return inc.Values.Last() +} + +func (inc *StdDev) Index(i int) float64 { + if i >= inc.Values.Length() { + return 0.0 + } + + return inc.Values.Index(i) +} + +func (inc *StdDev) Length() int { + return inc.Values.Length() +} + +var _ types.SeriesExtend = &StdDev{} + +func (inc *StdDev) Update(value float64) { + if inc.rawValues == nil { + inc.rawValues = types.NewQueue(inc.Window) + inc.SeriesBase.Series = inc + } + + inc.rawValues.Update(value) + + var std = inc.rawValues.Stdev() + inc.Values.Push(std) +} + +func (inc *StdDev) PushK(k types.KLine) { + inc.Update(k.Close.Float64()) + inc.EndTime = k.EndTime.Time() +} + +func (inc *StdDev) CalculateAndUpdate(allKLines []types.KLine) { + var last = allKLines[len(allKLines)-1] + + if inc.rawValues == nil { + for _, k := range allKLines { + if inc.EndTime != zeroTime && k.EndTime.Before(inc.EndTime) { + continue + } + inc.PushK(k) + } + } else { + inc.PushK(last) + } + + inc.EmitUpdate(inc.Values.Last()) +} + +func (inc *StdDev) handleKLineWindowUpdate(interval types.Interval, window types.KLineWindow) { + if inc.Interval != interval { + return + } + + inc.CalculateAndUpdate(window) +} + +func (inc *StdDev) Bind(updater KLineWindowUpdater) { + updater.OnKLineWindowUpdate(inc.handleKLineWindowUpdate) +} diff --git a/pkg/indicator/stddev_callbacks.go b/pkg/indicator/stddev_callbacks.go new file mode 100644 index 0000000000..745f006eeb --- /dev/null +++ b/pkg/indicator/stddev_callbacks.go @@ -0,0 +1,15 @@ +// Code generated by "callbackgen -type StdDev"; DO NOT EDIT. + +package indicator + +import () + +func (inc *StdDev) OnUpdate(cb func(value float64)) { + inc.updateCallbacks = append(inc.updateCallbacks, cb) +} + +func (inc *StdDev) EmitUpdate(value float64) { + for _, cb := range inc.updateCallbacks { + cb(value) + } +} diff --git a/pkg/strategy/bollmaker/strategy.go b/pkg/strategy/bollmaker/strategy.go index ff3c31eeea..564f351249 100644 --- a/pkg/strategy/bollmaker/strategy.go +++ b/pkg/strategy/bollmaker/strategy.go @@ -278,9 +278,9 @@ func (s *Strategy) placeOrders(ctx context.Context, midPrice fixedpoint.Value, k baseBalance, hasBaseBalance := balances[s.Market.BaseCurrency] quoteBalance, hasQuoteBalance := balances[s.Market.QuoteCurrency] - downBand := s.defaultBoll.LastDownBand() - upBand := s.defaultBoll.LastUpBand() - sma := s.defaultBoll.LastSMA() + downBand := s.defaultBoll.DownBand.Last() + upBand := s.defaultBoll.UpBand.Last() + sma := s.defaultBoll.SMA.Last() log.Infof("%s bollinger band: up %f sma %f down %f", s.Symbol, upBand, sma, downBand) bandPercentage := calculateBandPercentage(upBand, downBand, sma, midPrice.Float64()) @@ -349,7 +349,7 @@ func (s *Strategy) placeOrders(ctx context.Context, midPrice fixedpoint.Value, k // WHEN: price breaks the upper band (price > window 2) == strongUpTrend // THEN: we apply strongUpTrend skew if s.TradeInBand { - if !inBetween(midPrice.Float64(), s.neutralBoll.LastDownBand(), s.neutralBoll.LastUpBand()) { + if !inBetween(midPrice.Float64(), s.neutralBoll.DownBand.Last(), s.neutralBoll.UpBand.Last()) { log.Infof("tradeInBand is set, skip placing orders when the price is outside of the band") return } @@ -402,7 +402,7 @@ func (s *Strategy) placeOrders(ctx context.Context, midPrice fixedpoint.Value, k canSell = false } - if s.BuyBelowNeutralSMA && midPrice.Float64() > s.neutralBoll.LastSMA() { + if s.BuyBelowNeutralSMA && midPrice.Float64() > s.neutralBoll.SMA.Last() { canBuy = false } diff --git a/pkg/strategy/bollmaker/trend.go b/pkg/strategy/bollmaker/trend.go index 654ac4cce4..33167967b5 100644 --- a/pkg/strategy/bollmaker/trend.go +++ b/pkg/strategy/bollmaker/trend.go @@ -12,7 +12,7 @@ const ( ) func detectPriceTrend(inc *indicator.BOLL, price float64) PriceTrend { - if inBetween(price, inc.LastDownBand(), inc.LastUpBand()) { + if inBetween(price, inc.DownBand.Last(), inc.UpBand.Last()) { return NeutralTrend } diff --git a/pkg/strategy/rsmaker/strategy.go b/pkg/strategy/rsmaker/strategy.go index 1eada96c2a..0a860bea07 100644 --- a/pkg/strategy/rsmaker/strategy.go +++ b/pkg/strategy/rsmaker/strategy.go @@ -336,9 +336,9 @@ func (s *Strategy) placeOrders(ctx context.Context, midPrice fixedpoint.Value, k // baseBalance, hasBaseBalance := balances[s.Market.BaseCurrency] // quoteBalance, hasQuoteBalance := balances[s.Market.QuoteCurrency] - downBand := s.defaultBoll.LastDownBand() - upBand := s.defaultBoll.LastUpBand() - sma := s.defaultBoll.LastSMA() + downBand := s.defaultBoll.DownBand.Last() + upBand := s.defaultBoll.UpBand.Last() + sma := s.defaultBoll.SMA.Last() log.Infof("bollinger band: up %f sma %f down %f", upBand, sma, downBand) bandPercentage := calculateBandPercentage(upBand, downBand, sma, midPrice.Float64()) diff --git a/pkg/strategy/xmaker/strategy.go b/pkg/strategy/xmaker/strategy.go index 5c5c8aa784..b469591c8b 100644 --- a/pkg/strategy/xmaker/strategy.go +++ b/pkg/strategy/xmaker/strategy.go @@ -305,8 +305,8 @@ func (s *Strategy) updateQuote(ctx context.Context, orderExecutionRouter bbgo.Or var pips = s.Pips if s.EnableBollBandMargin { - lastDownBand := fixedpoint.NewFromFloat(s.boll.LastDownBand()) - lastUpBand := fixedpoint.NewFromFloat(s.boll.LastUpBand()) + lastDownBand := fixedpoint.NewFromFloat(s.boll.DownBand.Last()) + lastUpBand := fixedpoint.NewFromFloat(s.boll.UpBand.Last()) if lastUpBand.IsZero() || lastDownBand.IsZero() { log.Warnf("bollinger band value is zero, skipping")