diff --git a/pkg/indicator/obv.go b/pkg/indicator/obv.go index 507eeea9ac..2ca18a63f4 100644 --- a/pkg/indicator/obv.go +++ b/pkg/indicator/obv.go @@ -32,15 +32,13 @@ func (inc *OBV) update(kLine types.KLine, priceF KLinePriceMapper) { return } - preOBV := inc.Values[len(inc.Values)-1] - var sign float64 = 0.0 if volume > inc.PrePrice { sign = 1.0 } else if volume < inc.PrePrice { sign = -1.0 } - obv := preOBV + sign*volume + obv := inc.Last() + sign*volume inc.Values.Push(obv) } diff --git a/pkg/indicator/obv_test.go b/pkg/indicator/obv_test.go new file mode 100644 index 0000000000..2085306007 --- /dev/null +++ b/pkg/indicator/obv_test.go @@ -0,0 +1,47 @@ +package indicator + +import ( + "reflect" + "testing" + + "github.com/c9s/bbgo/pkg/types" +) + +func Test_calculateAndUpdate(t *testing.T) { + buildKLines := func(prices, volumes []float64) (kLines []types.KLine) { + for i, p := range prices { + kLines = append(kLines, types.KLine{High: p, Low: p, Close: p, Volume: volumes[i]}) + } + return kLines + } + + tests := []struct { + name string + kLines []types.KLine + window int + want Float64Slice + }{ + { + name: "trivial_case", + kLines: buildKLines([]float64{0}, []float64{1}), + window: 0, + want: Float64Slice{1.0}, + }, + { + name: "easy_case", + kLines: buildKLines([]float64{3, 2, 1, 4}, []float64{3, 2, 2, 6}), + window: 0, + want: Float64Slice{3, 1, -1, 5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + obv := OBV{IntervalWindow: types.IntervalWindow{Window: tt.window}} + obv.calculateAndUpdate(tt.kLines) + if !reflect.DeepEqual(obv.Values, tt.want) { + t.Errorf("calculateAndUpdate() = %v, want %v", obv.Values, tt.want) + } + }) + } +}