Skip to content

Commit

Permalink
Merge pull request #77 from jocon15/feature/parse-key-functions
Browse files Browse the repository at this point in the history
created base class rule key parser functions to use in indicator trig…
  • Loading branch information
jocon15 authored Dec 21, 2024
2 parents 6acba70 + 2df1c7f commit 4b684bc
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 252 deletions.
14 changes: 7 additions & 7 deletions src/StockBench/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_additional_days(self) -> int:
key = keys[i]
value = values[i]
for trigger in triggers:
if trigger.strategy_symbol in key:
if trigger.indicator_symbol in key:
num = trigger.additional_days(key, value)
if additional_days < num:
additional_days = num
Expand All @@ -95,11 +95,11 @@ def add_indicator_data(self, data_manager) -> None:
# find all buy algorithm and add their indicator to the data
for key in self.strategy[BUY_SIDE].keys():
for trigger in triggers:
if trigger.strategy_symbol in key:
if trigger.indicator_symbol in key:
trigger.add_to_data(key, self.strategy[BUY_SIDE][key], BUY_SIDE, data_manager)
elif AND_KEY in key:
for inner_key in self.strategy[BUY_SIDE][key].keys():
if trigger.strategy_symbol in inner_key:
if trigger.indicator_symbol in inner_key:
trigger.add_to_data(inner_key, self.strategy[BUY_SIDE][key][inner_key], BUY_SIDE,
data_manager)

Expand All @@ -109,11 +109,11 @@ def add_indicator_data(self, data_manager) -> None:
# find all sell algorithm and add their indicator to the data
for key in self.strategy[SELL_SIDE].keys():
for trigger in triggers:
if trigger.strategy_symbol in key:
if trigger.indicator_symbol in key:
trigger.add_to_data(key, self.strategy[SELL_SIDE][key], SELL_SIDE, data_manager)
elif AND_KEY in key:
for inner_key in self.strategy[SELL_SIDE][key].keys():
if trigger.strategy_symbol in inner_key:
if trigger.indicator_symbol in inner_key:
trigger.add_to_data(inner_key, self.strategy[SELL_SIDE][key][inner_key], SELL_SIDE,
data_manager)

Expand Down Expand Up @@ -244,7 +244,7 @@ def __handle_and_triggers(self, triggers: List[Trigger], data_manager: DataManag
key_matched_with_trigger = False
# check all algorithm
for trigger in triggers:
if trigger.strategy_symbol in inner_key:
if trigger.indicator_symbol in inner_key:
key_matched_with_trigger = True
trigger_hit = trigger.check_trigger(
inner_key,
Expand Down Expand Up @@ -280,7 +280,7 @@ def __handle_or_triggers(self, triggers: List[Trigger], data_manager: DataManage
key_matched_with_trigger = False
# check all algorithm
for trigger in triggers:
if trigger.strategy_symbol in key:
if trigger.indicator_symbol in key:
key_matched_with_trigger = True
trigger_hit = trigger.check_trigger(
key,
Expand Down
121 changes: 119 additions & 2 deletions src/StockBench/indicator/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class Trigger:
SELL = 1
AGNOSTIC = 2

def __init__(self, strategy_symbol: str, side: str):
self.strategy_symbol = strategy_symbol
def __init__(self, indicator_symbol: str, side: str):
self.indicator_symbol = indicator_symbol
self.__side = side

def get_side(self):
Expand Down Expand Up @@ -63,6 +63,123 @@ def _parse_rule_value(self, rule_value: str, data_manager: DataManager,

return operator, trigger_value

@staticmethod
def _parse_rule_key(rule_key: str, indicator_symbol: str, data_manager: DataManager,
current_day_index: int) -> float:
"""Translates a complex rule key for an indicator value where the indicator has a default value.
Can have 0, 1, or 2 number groupings.
"""
rule_key_number_groups = Trigger.find_all_nums_in_str(rule_key)
if len(rule_key_number_groups) == 0:
# rule key does not define an indicator length (use default)
if SLOPE_SYMBOL in rule_key:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} does not contain '
f'enough number groupings!')
indicator_value = float(data_manager.get_data_point(indicator_symbol, current_day_index))
elif len(rule_key_number_groups) == 1:
if SLOPE_SYMBOL in rule_key:
# make sure the number is after the slope emblem and not the RSI emblem
if rule_key.split(str(rule_key_number_groups))[0] == indicator_symbol + SLOPE_SYMBOL:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} does not contain '
f'a slope value!')
# rule key defines an indicator length (not using default)
column_title = f'{indicator_symbol}{int(rule_key_number_groups[0])}'
indicator_value = float(data_manager.get_data_point(column_title, current_day_index))
elif len(rule_key_number_groups) == 2:
column_title = f'{indicator_symbol}{int(rule_key_number_groups[0])}'
# 2 number groupings suggests the $slope indicator is being used
if SLOPE_SYMBOL in rule_key:
slope_window_length = int(rule_key_number_groups[1])

# data request length is window - 1 to account for the current day index being a part of the window
slope_data_request_length = slope_window_length - 1

indicator_value = Trigger.calculate_slope(
float(data_manager.get_data_point(column_title, current_day_index)),
float(data_manager.get_data_point(column_title, current_day_index - slope_data_request_length)),
slope_window_length
)
else:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} contains too many number '
f'groupings! Are you missing a $slope emblem?')
else:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} contains invalid number '
f'groupings!')

return indicator_value

@staticmethod
def _parse_rule_key_no_default_indicator_length(rule_key: str, indicator_symbol: str, data_manager: DataManager,
current_day_index: int) -> float:
"""Translates a complex rule key for an indicator value where the indicator DOES NOT have a default value.
Can have 1, or 2 number groupings.
"""
key_number_groupings = Trigger.find_all_nums_in_str(rule_key)

if len(key_number_groupings) == 1:
if SLOPE_SYMBOL in rule_key:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} does not contain '
f'enough number groupings!')
column_title = f'{indicator_symbol}{int(key_number_groupings[0])}'
indicator_value = float(data_manager.get_data_point(column_title, current_day_index))
elif len(key_number_groupings) == 2:
column_title = f'{indicator_symbol}{int(key_number_groupings[0])}'
# 2 number groupings suggests the $slope indicator is being used
if SLOPE_SYMBOL in rule_key:
slope_window_length = int(key_number_groupings[1])

# data request length is window - 1 to account for the current day index being a part of the window
slope_data_request_length = slope_window_length - 1

indicator_value = Trigger.calculate_slope(
float(data_manager.get_data_point(column_title, current_day_index)),
float(data_manager.get_data_point(column_title, current_day_index - slope_data_request_length)),
slope_window_length
)
else:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} contains too many number '
f'groupings! Are you missing a $slope emblem?')
else:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} contains invalid number '
f'groupings!')

return indicator_value

@staticmethod
def _parse_rule_key_no_indicator_length(rule_key: str, indicator_symbol: str, data_manager: DataManager,
current_day_index: int) -> float:
"""Parser for parsing the key into the indicator value."""
key_number_groupings = Trigger.find_all_nums_in_str(rule_key)

# MACD can only have slope emblem therefore 1 or 0 number groupings are acceptable
if len(key_number_groupings) == 0:
if SLOPE_SYMBOL in rule_key:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} does not contain'
f' enough number groupings!')
indicator_value = float(data_manager.get_data_point(indicator_symbol, current_day_index))
elif len(key_number_groupings) == 1:
# 1 number grouping suggests the $slope indicator is being used
if SLOPE_SYMBOL in rule_key:
slope_window_length = int(key_number_groupings[0])

# data request length is window - 1 to account for the current day index being a part of the window
slope_data_request_length = slope_window_length - 1

indicator_value = Trigger.calculate_slope(
float(data_manager.get_data_point(indicator_symbol, current_day_index)),
float(data_manager.get_data_point(indicator_symbol, current_day_index -
slope_data_request_length)),
slope_window_length
)
else:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} contains too many number '
f'groupings! Are you missing a $slope emblem?')
else:
raise StrategyIndicatorError(f'{indicator_symbol} rule key: {rule_key} contains '
f'invalid number groupings!')

return indicator_value

@staticmethod
def _add_trigger_column(column_name: str, trigger_value: float, data_manager: DataManager):
"""Add a trigger value to the df."""
Expand Down
8 changes: 4 additions & 4 deletions src/StockBench/indicators/candlestick_color/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class CandlestickColorTrigger(Trigger):
def __init__(self, strategy_symbol):
super().__init__(strategy_symbol, side=Trigger.AGNOSTIC)
def __init__(self, indicator_symbol):
super().__init__(indicator_symbol, side=Trigger.AGNOSTIC)

def additional_days(self, rule_key, value_value) -> int:
"""Calculate the additional days required.
Expand All @@ -19,7 +19,7 @@ def additional_days(self, rule_key, value_value) -> int:
value_value (any): The value from the strategy.
"""
if len(value_value.keys()) == 0:
raise StrategyIndicatorError(f'{self.strategy_symbol} key: {rule_key} must have at least one color child '
raise StrategyIndicatorError(f'{self.indicator_symbol} key: {rule_key} must have at least one color child '
f'key')

additional_days = 0
Expand Down Expand Up @@ -58,7 +58,7 @@ def check_trigger(self, rule_key, rule_value, data_manager, position, current_da
key_count = len(rule_value)

if key_count == 0:
raise StrategyIndicatorError(f'{self.strategy_symbol} key: {rule_key} must have at least one color child '
raise StrategyIndicatorError(f'{self.indicator_symbol} key: {rule_key} must have at least one color child '
f'key')

trigger_colors = [rule_value[value_key] for value_key in sorted(rule_value.keys())]
Expand Down
51 changes: 9 additions & 42 deletions src/StockBench/indicators/ema/trigger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import statistics
from StockBench.constants import *
from StockBench.indicator.trigger import Trigger
from StockBench.indicator.exceptions import StrategyIndicatorError
from StockBench.simulation_data.data_manager import DataManager
Expand All @@ -9,8 +8,8 @@


class EMATrigger(Trigger):
def __init__(self, strategy_symbol):
super().__init__(strategy_symbol, side=Trigger.AGNOSTIC)
def __init__(self, indicator_symbol):
super().__init__(indicator_symbol, side=Trigger.AGNOSTIC)

def additional_days(self, rule_key, value_value) -> int:
"""Calculate the additional days required.
Expand All @@ -23,7 +22,7 @@ def additional_days(self, rule_key, value_value) -> int:
nums = list(map(int, self.find_all_nums_in_str(rule_key)))
if nums:
return max(nums)
raise StrategyIndicatorError(f'{self.strategy_symbol} key: {rule_key} must have an indicator length!')
raise StrategyIndicatorError(f'{self.indicator_symbol} key: {rule_key} must have an indicator length!')

def add_to_data(self, rule_key, rule_value, side, data_manager):
"""Add data to the dataframe.
Expand All @@ -39,7 +38,7 @@ def add_to_data(self, rule_key, rule_value, side, data_manager):
indicator_length = int(nums[0])
self.__add_ema(indicator_length, data_manager)
else:
raise StrategyIndicatorError(f'{self.strategy_symbol} key: {rule_key} must have an indicator length!')
raise StrategyIndicatorError(f'{self.indicator_symbol} key: {rule_key} must have an indicator length!')

def check_trigger(self, rule_key, rule_value, data_manager, position, current_day_index) -> bool:
"""Trigger logic for EMA.
Expand All @@ -54,52 +53,20 @@ def check_trigger(self, rule_key, rule_value, data_manager, position, current_da
return:
bool: True if a trigger was hit.
"""
log.debug(f'Checking {self.strategy_symbol} algorithm: {rule_key}...')
log.debug(f'Checking {self.indicator_symbol} algorithm: {rule_key}...')

indicator_value = self.__parse_key(rule_key, data_manager, current_day_index)
indicator_value = Trigger._parse_rule_key_no_default_indicator_length(rule_key, self.indicator_symbol, data_manager,
current_day_index)

operator, trigger_value = self._parse_rule_value(rule_value, data_manager, current_day_index)

log.debug(f'{self.strategy_symbol} algorithm: {rule_key} checked successfully')
log.debug(f'{self.indicator_symbol} algorithm: {rule_key} checked successfully')

return Trigger.basic_trigger_check(indicator_value, operator, trigger_value)

def __parse_key(self, rule_key: str, data_manager: DataManager, current_day_index: int) -> float:
"""Parser for parsing the key into the indicator value."""
key_number_groupings = self.find_all_nums_in_str(rule_key)

if len(key_number_groupings) == 1:
if SLOPE_SYMBOL in rule_key:
raise StrategyIndicatorError(f'{self.strategy_symbol} rule key: {rule_key} does not contain '
f'enough number groupings!')
column_title = f'{self.strategy_symbol}{int(key_number_groupings[0])}'
indicator_value = float(data_manager.get_data_point(column_title, current_day_index))
elif len(key_number_groupings) == 2:
column_title = f'{self.strategy_symbol}{int(key_number_groupings[0])}'
# 2 number groupings suggests the $slope indicator is being used
if SLOPE_SYMBOL in rule_key:
slope_window_length = int(key_number_groupings[1])

# data request length is window - 1 to account for the current day index being a part of the window
slope_data_request_length = slope_window_length - 1

indicator_value = self.calculate_slope(
float(data_manager.get_data_point(column_title, current_day_index)),
float(data_manager.get_data_point(column_title, current_day_index - slope_data_request_length)),
slope_window_length
)
else:
raise StrategyIndicatorError(f'{self.strategy_symbol} rule key: {rule_key} contains too many number '
f'groupings! Are you missing a $slope emblem?')
else:
raise StrategyIndicatorError(f'{self.strategy_symbol} rule key: {rule_key} contains invalid number '
f'groupings!')

return indicator_value

def __add_ema(self, length: int, data_manager: DataManager):
"""Pre-calculate the EMA values and add them to the df."""
column_title = f'{self.strategy_symbol}{length}'
column_title = f'{self.indicator_symbol}{length}'

# if we already have EMA values in the df, we don't need to add them again
for col_name in data_manager.get_column_names():
Expand Down
Loading

0 comments on commit 4b684bc

Please sign in to comment.