Skip to content

Commit 172767c

Browse files
ivaigultkernc
andcommitted
BUG: Allow multiple names for vector indicators (#382)
Previously we only allowed one name per vector indicator: def _my_indicator(open, close): return tuple( _my_indicator_one(open, close), _my_indicator_two(open, close), ) self.I( _my_indicator, # One name is used to describe two values name="My Indicator", self.data.Open, self.data.Close ) Now, the user can supply two (or more) names to annotate each value individually. The names will be shown in the plot legend. The following is now valid: self.I( _my_indicator, # Two names can now be passed name=["My Indicator One", "My Indicator Two"], self.data.Open, self.data.Close ) Co-authored-by: kernc <kerncece@gmail.com>
1 parent 0ce24d8 commit 172767c

File tree

3 files changed

+71
-13
lines changed

3 files changed

+71
-13
lines changed

backtesting/_plotting.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
import warnings
55
from colorsys import hls_to_rgb, rgb_to_hls
6-
from itertools import cycle, combinations
6+
from itertools import cycle, combinations, repeat
77
from functools import partial
88
from typing import Callable, List, Union
99

@@ -537,10 +537,22 @@ def __eq__(self, other):
537537
colors = value._opts['color']
538538
colors = colors and cycle(_as_list(colors)) or (
539539
cycle([next(ohlc_colors)]) if is_overlay else colorgen())
540-
legend_label = LegendStr(value.name)
541-
for j, arr in enumerate(value, 1):
540+
541+
tooltip_label = value.name if isinstance(value.name, str) else ", ".join(value.name)
542+
543+
if len(value) == 1:
544+
legend_labels = [LegendStr(item) for item in _as_list(value.name)]
545+
elif isinstance(value.name, str):
546+
legend_labels = [
547+
LegendStr(f"{name}[{index}]")
548+
for index, name in enumerate(repeat(value.name, len(value)))
549+
]
550+
else:
551+
legend_labels = [LegendStr(item) for item in value.name]
552+
553+
for j, arr in enumerate(value):
542554
color = next(colors)
543-
source_name = f'{legend_label}_{i}_{j}'
555+
source_name = f'{legend_labels[j]}_{i}_{j}'
544556
if arr.dtype == bool:
545557
arr = arr.astype(int)
546558
source.add(arr, source_name)
@@ -550,24 +562,24 @@ def __eq__(self, other):
550562
if is_scatter:
551563
fig.scatter(
552564
'index', source_name, source=source,
553-
legend_label=legend_label, color=color,
565+
legend_label=legend_labels[j], color=color,
554566
line_color='black', fill_alpha=.8,
555567
marker='circle', radius=BAR_WIDTH / 2 * 1.5)
556568
else:
557569
fig.line(
558570
'index', source_name, source=source,
559-
legend_label=legend_label, line_color=color,
571+
legend_label=legend_labels[j], line_color=color,
560572
line_width=1.3)
561573
else:
562574
if is_scatter:
563575
r = fig.scatter(
564576
'index', source_name, source=source,
565-
legend_label=LegendStr(legend_label), color=color,
577+
legend_label=legend_labels[j], color=color,
566578
marker='circle', radius=BAR_WIDTH / 2 * .9)
567579
else:
568580
r = fig.line(
569581
'index', source_name, source=source,
570-
legend_label=LegendStr(legend_label), line_color=color,
582+
legend_label=legend_labels[j], line_color=color,
571583
line_width=1.3)
572584
# Add dashed centerline just because
573585
mean = float(pd.Series(arr).mean())
@@ -578,9 +590,9 @@ def __eq__(self, other):
578590
line_color='#666666', line_dash='dashed',
579591
line_width=.5))
580592
if is_overlay:
581-
ohlc_tooltips.append((legend_label, NBSP.join(tooltips)))
593+
ohlc_tooltips.append((tooltip_label, NBSP.join(tooltips)))
582594
else:
583-
set_tooltips(fig, [(legend_label, NBSP.join(tooltips))], vline=True, renderers=[r])
595+
set_tooltips(fig, [(tooltip_label, NBSP.join(tooltips))], vline=True, renderers=[r])
584596
# If the sole indicator line on this figure,
585597
# have the legend only contain text without the glyph
586598
if len(value) == 1:

backtesting/backtesting.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ def I(self, # noqa: E743
9090
same length as `backtesting.backtesting.Strategy.data`.
9191
9292
In the plot legend, the indicator is labeled with
93-
function name, unless `name` overrides it.
93+
function name, unless `name` overrides it. If `func` returns
94+
multiple arrays, `name` can be a sequence of strings, and
95+
its size must agree with the number of arrays returned.
9496
9597
If `plot` is `True`, the indicator is plotted on the resulting
9698
`backtesting.backtesting.Backtest.plot`.
@@ -115,13 +117,21 @@ def I(self, # noqa: E743
115117
def init():
116118
self.sma = self.I(ta.SMA, self.data.Close, self.n_sma)
117119
"""
120+
def _format_name(name: str) -> str:
121+
return name.format(*map(_as_str, args),
122+
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
123+
118124
if name is None:
119125
params = ','.join(filter(None, map(_as_str, chain(args, kwargs.values()))))
120126
func_name = _as_str(func)
121127
name = (f'{func_name}({params})' if params else f'{func_name}')
128+
elif isinstance(name, str):
129+
name = _format_name(name)
130+
elif try_(lambda: all(isinstance(item, str) for item in name), False):
131+
name = [_format_name(item) for item in name]
122132
else:
123-
name = name.format(*map(_as_str, args),
124-
**dict(zip(kwargs.keys(), map(_as_str, kwargs.values()))))
133+
raise TypeError(f'Unexpected `name=` type {type(name)}; expected `str` or '
134+
'`Sequence[str]`')
125135

126136
try:
127137
value = func(*args, **kwargs)
@@ -139,6 +149,11 @@ def init():
139149
if is_arraylike and np.argmax(value.shape) == 0:
140150
value = value.T
141151

152+
if isinstance(name, list) and (np.atleast_2d(value).shape[0] != len(name)):
153+
raise ValueError(
154+
f'Length of `name=` ({len(name)}) must agree with the number '
155+
f'of arrays the indicator returns ({value.shape[0]}).')
156+
142157
if not is_arraylike or not 1 <= value.ndim <= 2 or value.shape[-1] != len(self._data.Close):
143158
raise ValueError(
144159
'Indicators must return (optionally a tuple of) numpy.arrays of same '

backtesting/test/_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,37 @@ def test_resample(self):
755755
# Give browser time to open before tempfile is removed
756756
time.sleep(1)
757757

758+
def test_indicator_name(self):
759+
test_self = self
760+
761+
class S(Strategy):
762+
def init(self):
763+
def _SMA():
764+
return SMA(self.data.Close, 5), SMA(self.data.Close, 10)
765+
766+
test_self.assertRaises(TypeError, self.I, _SMA, name=42)
767+
test_self.assertRaises(ValueError, self.I, _SMA, name=("SMA One", ))
768+
test_self.assertRaises(
769+
ValueError, self.I, _SMA, name=("SMA One", "SMA Two", "SMA Three"))
770+
771+
for overlay in (True, False):
772+
self.I(SMA, self.data.Close, 5, overlay=overlay)
773+
self.I(SMA, self.data.Close, 5, name="My SMA", overlay=overlay)
774+
self.I(SMA, self.data.Close, 5, name=("My SMA", ), overlay=overlay)
775+
self.I(_SMA, overlay=overlay)
776+
self.I(_SMA, name="My SMA", overlay=overlay)
777+
self.I(_SMA, name=("SMA One", "SMA Two"), overlay=overlay)
778+
779+
def next(self):
780+
pass
781+
782+
bt = Backtest(GOOG, S)
783+
bt.run()
784+
with _tempfile() as f:
785+
bt.plot(filename=f,
786+
plot_drawdown=False, plot_equity=False, plot_pl=False, plot_volume=False,
787+
open_browser=False)
788+
758789
def test_indicator_color(self):
759790
class S(Strategy):
760791
def init(self):

0 commit comments

Comments
 (0)