Skip to content

Commit

Permalink
Merge pull request #78 from jocon15/revision/single-indicator-calcula…
Browse files Browse the repository at this point in the history
…tions

Revision/single indicator calculations
  • Loading branch information
jocon15 authored Dec 22, 2024
2 parents 4b684bc + 6d503e3 commit 8cb24de
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 86 deletions.
18 changes: 4 additions & 14 deletions src/StockBench/charting/singular/singular_charting_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,7 @@ class SingularChartingEngine(ChartingEngine):
def build_indicator_chart(df: DataFrame, symbol: str, available_indicators: List[IndicatorInterface],
show_volume: bool,
save_option=ChartingEngine.TEMP_SAVE) -> str:
"""Multi-plot chart for singular simulation indicators.
Args:
df: The full DataFrame post-simulation.
symbol: The symbol the simulation was run on.
available_indicators: The list of indicators.
save_option: Save the chart.
Return:
(str): The filepath of the chart
"""
"""Multi-plot chart for singular simulation indicators."""
subplot_objects, subplot_types = SingularChartingEngine.__get_subplot_objects_and_types(df,
available_indicators)

Expand All @@ -56,9 +46,9 @@ def build_indicator_chart(df: DataFrame, symbol: str, available_indicators: List
def build_account_value_bar_chart(df: DataFrame, symbol: str, save_option=ChartingEngine.TEMP_SAVE) -> str:
"""Builds a chart for duration of positions.
return:
str: The filepath of the built chart.
"""
return:
str: The filepath of the built chart.
"""
rows = 1
cols = 1

Expand Down
6 changes: 3 additions & 3 deletions src/StockBench/gui/studio/strategy_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def on_save_btn_clicked(self):
self.on_save_as_btn_clicked()

def on_save_as_btn_clicked(self):
fileName, _ = QFileDialog.getSaveFileName(self, "QFileDialog.getSaveFileName()", "", "JSON (*.json)")
if fileName is not None and fileName != '':
file_name, _ = QFileDialog.getSaveFileName(self, "QFileDialog.getSaveFileName()", "", "JSON (*.json)")
if file_name is not None and file_name != '':
# only save the file if the user picked a location
self.__save_json_file(fileName)
self.__save_json_file(file_name)

def __set_geometry(self, config_pos, config_width):
# place the strategy studio to the right of the config window
Expand Down
24 changes: 4 additions & 20 deletions src/StockBench/indicators/ema/trigger.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import statistics
from StockBench.indicator.trigger import Trigger
from StockBench.indicator.exceptions import StrategyIndicatorError
from StockBench.simulation_data.data_manager import DataManager
from StockBench.indicators.sma.sma import SMATrigger

log = logging.getLogger()

Expand Down Expand Up @@ -75,17 +75,17 @@ def __add_ema(self, length: int, data_manager: DataManager):

price_data = data_manager.get_column_data(data_manager.CLOSE)

ema_values = EMATrigger.__calculate_ema(length, price_data)
ema_values = EMATrigger.calculate_ema(length, price_data)

data_manager.add_column(column_title, ema_values)

@staticmethod
def __calculate_ema(length: int, price_data: list) -> list:
def calculate_ema(length: int, price_data: list) -> list:
"""Calculates the EMA values for a list of price values"""
k = 2 / (length + 1)

# get the initial ema value (uses sma of length days)
previous_ema = EMATrigger.__calculate_sma(length, price_data[0:length])[-1]
previous_ema = SMATrigger.calculate_sma(length, price_data[0:length])[-1]

ema_values = []
for i in range(len(price_data)):
Expand All @@ -96,19 +96,3 @@ def __calculate_ema(length: int, price_data: list) -> list:
ema_values.append(ema_point)
previous_ema = ema_point
return ema_values

@staticmethod
def __calculate_sma(length: int, price_data: list) -> list:
"""Calculates the SMA values for a list of price values."""
price_values = []
sma_values = []
for day in price_data:
if len(price_values) < length:
price_values.append(float(day))
else:
price_values.pop(0)
sma_values.pop(0)
price_values.append(float(day))
avg = round(statistics.mean(price_values), 3)
sma_values.append(avg)
return sma_values
46 changes: 5 additions & 41 deletions src/StockBench/indicators/macd/trigger.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
import statistics
from StockBench.constants import *
from StockBench.indicator.trigger import Trigger
from StockBench.indicator.exceptions import StrategyIndicatorError
from StockBench.simulation_data.data_manager import DataManager
from StockBench.position.position import Position
from StockBench.indicators.ema.ema import EMATrigger

log = logging.getLogger()

Expand Down Expand Up @@ -43,7 +42,7 @@ def add_to_data(self, rule_key, rule_value, side, data_manager):

price_data = data_manager.get_column_data(data_manager.CLOSE)

data_manager.add_column(self.indicator_symbol, self.__calculate_macd(price_data))
data_manager.add_column(self.indicator_symbol, self.calculate_macd(price_data))

def check_trigger(self, rule_key, rule_value, data_manager, position, current_day_index) -> bool:
"""Trigger logic for EMA.
Expand All @@ -69,11 +68,11 @@ def check_trigger(self, rule_key, rule_value, data_manager, position, current_da

return Trigger.basic_trigger_check(indicator_value, operator, trigger_value)

def __calculate_macd(self, price_data: list) -> list:
def calculate_macd(self, price_data: list) -> list:
"""Calculate MACD values for a list of price values"""
large_ema_length_values = MACDTrigger.__calculate_ema(self.LARGE_EMA_LENGTH, price_data)
large_ema_length_values = EMATrigger.calculate_ema(self.LARGE_EMA_LENGTH, price_data)

small_ema_length_values = MACDTrigger.__calculate_ema(self.SMALL_EMA_LENGTH, price_data)
small_ema_length_values = EMATrigger.calculate_ema(self.SMALL_EMA_LENGTH, price_data)

if len(large_ema_length_values) != len(small_ema_length_values):
raise StrategyIndicatorError(f'{self.indicator_symbol} value lists for {self.indicator_symbol} must be the '
Expand All @@ -89,38 +88,3 @@ def __calculate_macd(self, price_data: list) -> list:
macd_values.append(round(small_ema_length_values[i] - large_ema_length_values[i], 3))

return macd_values

@staticmethod
def __calculate_ema(length: int, price_data: list) -> list:
"""Calculates the EMA values for a list of price values."""
k = 2 / (length + 1)

previous_ema = MACDTrigger.__calculate_sma(length, price_data[0:length])[-1]

ema_values = []
for i in range(len(price_data)):
if i < length:
ema_values.append(None)
else:
ema = round((k * (float(price_data[i]) - previous_ema)) + previous_ema, 3)
ema_values.append(ema)
previous_ema = ema
return ema_values

@staticmethod
def __calculate_sma(length: int, price_data: list) -> list:
"""Calculates the SMA values for a list of price values."""
price_values = []
sma_values = []
all_sma_values = []
for element in price_data:
if len(price_values) < length:
price_values.append(float(element))
else:
price_values.pop(0)
sma_values.pop(0)
price_values.append(float(element))
avg = round(statistics.mean(price_values), 3)
sma_values.append(avg)
all_sma_values.append(avg)
return all_sma_values
4 changes: 2 additions & 2 deletions src/StockBench/indicators/rsi/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ def __add_rsi_column(self, length: int, data_manager: DataManager):

price_data = data_manager.get_column_data(data_manager.CLOSE)

rsi_values = RSITrigger.__calculate_rsi(length, price_data)
rsi_values = RSITrigger.calculate_rsi(length, price_data)

data_manager.add_column(self.indicator_symbol, rsi_values)

@staticmethod
def __calculate_rsi(length: int, price_data: list) -> list:
def calculate_rsi(length: int, price_data: list) -> list:
"""Calculate the RSI values for a list of price values."""
first_day_value = 0
gain = []
Expand Down
5 changes: 2 additions & 3 deletions src/StockBench/indicators/sma/trigger.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import statistics
from StockBench.constants import *
from StockBench.indicator.trigger import Trigger
from StockBench.indicator.exceptions import StrategyIndicatorError
from StockBench.simulation_data.data_manager import DataManager
Expand Down Expand Up @@ -77,12 +76,12 @@ def __add_sma(self, length: int, data_manager: DataManager):

price_data = data_manager.get_column_data(data_manager.CLOSE)

sma_values = SMATrigger.__calculate_sma(length, price_data)
sma_values = SMATrigger.calculate_sma(length, price_data)

data_manager.add_column(column_title, sma_values)

@staticmethod
def __calculate_sma(length: int, price_data: list) -> list:
def calculate_sma(length: int, price_data: list) -> list:
"""Calculates the SMA values for a list of price values."""
price_values = []
sma_values = []
Expand Down
5 changes: 2 additions & 3 deletions src/StockBench/indicators/stochastic/trigger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from StockBench.constants import *
from StockBench.indicator.trigger import Trigger
from StockBench.indicator.exceptions import StrategyIndicatorError
from StockBench.simulation_data.data_manager import DataManager
from StockBench.position.position import Position

Expand Down Expand Up @@ -85,12 +84,12 @@ def __add_stochastic_column(self, length: int, data_manager: DataManager):
low_data = data_manager.get_column_data(data_manager.LOW)
close_data = data_manager.get_column_data(data_manager.CLOSE)

stochastic_values = StochasticTrigger.__stochastic_oscillator(length, high_data, low_data, close_data)
stochastic_values = StochasticTrigger.stochastic_oscillator(length, high_data, low_data, close_data)

data_manager.add_column(self.indicator_symbol, stochastic_values)

@staticmethod
def __stochastic_oscillator(length: int, high_data: list, low_data: list, close_data: list) -> list:
def stochastic_oscillator(length: int, high_data: list, low_data: list, close_data: list) -> list:
"""Calculate the stochastic values for a list of price values."""
past_length_days_high = []
past_length_days_low = []
Expand Down

0 comments on commit 8cb24de

Please sign in to comment.