feat(reinforcement learning): ability to train and use an agent trained by RL #444

Open · wants to merge 28 commits into master

Commits (28):
52616e6  feat(reinforcement-learning): wip: an environment for trading agent i…  (yakir4123, Apr 12, 2024)
561fc08  feat(jesse-rl): wip  (yakir4123, Apr 14, 2024)
fb9f97a  feat(reinforcement-learning): working environment + reinforce learning  (yakir4123, Apr 17, 2024)
8ee6b91  feat(reinforcement-learning): working environment + reinforce learning  (yakir4123, Apr 18, 2024)
1871d34  feat(reinforcement-learning): able to run multi-processing simulations  (yakir4123, Apr 18, 2024)
fb72dd3  feat(reinforcement-learning): able to run different simulations each …  (yakir4123, Apr 19, 2024)
c40fcb6  feat(jesse-rl): add train-config standardization.  (yakir4123, Apr 19, 2024)
dab0947  feat(jesse-rl): added canvas util library for plotting.  (yakir4123, Apr 20, 2024)
b8a47b9  Merge branch 'yakir/breakdown-simulation-function' into yakir/feat/je…  (yakir4123, Apr 20, 2024)
1fee52a  feat(reinforcement-learning): save the best agent and added test func…  (yakir4123, Apr 21, 2024)
998fbc7  feat(reinforcement-learning): make canvas works for windows.., need t…  (yakir4123, Apr 26, 2024)
b8de400  feat(jesse-rl): Add warmup candles capabilities  (yakir4123, Apr 27, 2024)
6370d1e  feat(jesse-rl): Add termination function to gym environment  (yakir4123, Apr 28, 2024)
a74a871  feat(jesse-rl): save agent output  (yakir4123, Apr 28, 2024)
cff98d3  feat(reinforcement-learning): fix save path  (yakir4123, Apr 29, 2024)
55790ae  feat(jesse-rl): move agent learn to subprocess, fix environment prepa…  (yakir4123, Apr 29, 2024)
389b831  feat(jesse-rl): move agent learn to subprocess, fix environment prepa…  (yakir4123, Apr 29, 2024)
d576952  feat(reinforcement-learning): fix save path & warmup candles  (yakir4123, May 3, 2024)
c9fdd5e  feat(reinforcement-learning): agent settings callback + load agent fo…  (yakir4123, May 4, 2024)
70b7b51  Merge branch 'yakir/breakdown-simulation-function' into yakir/feat/je…  (yakir4123, May 4, 2024)
ab2ab49  feat(reinforcement-learning): train_config is not depands on n_jobs  (yakir4123, May 5, 2024)
caa5460  feat(reinforce-learning): able to load and train already trained agent  (yakir4123, May 5, 2024)
5ed0e91  feat(reinforce-learning): works on removing gymnaisum and use the bac…  (yakir4123, May 6, 2024)
ebdffde  feat(reinforce-learning): fully removed environment  (yakir4123, May 6, 2024)
61918ff  feat(reinforce-learning): minor fix  (yakir4123, May 6, 2024)
85411a7  feat(reinforce-learning): delete unused code from gymnasium implement…  (yakir4123, May 6, 2024)
f24df42  feat(reinforce-learning): update rl  (yakir4123, May 8, 2024)
83af24c  remove(reinforce-learning): unused code  (yakir4123, May 23, 2024)
118 changes: 118 additions & 0 deletions jesse/libs/plotly_utils/canvas.py
@@ -0,0 +1,118 @@
from typing import Literal

import numpy as np
import plotly.graph_objects as go

from plotly.subplots import make_subplots

import io
import plotly.io as pio
from PIL import Image

# pio.kaleido.scope.mathjax = None


COLOR_FORMAT = Literal["grayscale", "rgb", "rgba"]


class Canvas:

    def __init__(self, width: int = 128, height: int = 128):
        self.fig = make_subplots(rows=1, cols=1, shared_xaxes=True)
        self.fig.update_layout(showlegend=False)
        self.fig.update_layout(
            margin=dict(l=0, r=0, t=0, b=0),
            autosize=False,
            width=width,
            height=height,
        )
        self.fig.update_xaxes(showticklabels=False)
        self.fig.update_yaxes(showticklabels=False)

    def add_candles(self, candles: np.ndarray, count: int = -1) -> None:
        # Jesse candle columns: 0=timestamp, 1=open, 2=close, 3=high, 4=low, 5=volume.
        # Treat a non-positive count as "plot everything"; slicing with
        # candles[-count:] when count == -1 would silently drop the first candle.
        if count > 0:
            candles = candles[-count:]
        self.fig.add_trace(
            go.Candlestick(
                open=candles[:, 1],
                close=candles[:, 2],
                high=candles[:, 3],
                low=candles[:, 4],
                increasing={'line': {'color': 'green', 'width': 1}, 'fillcolor': 'green'},
                decreasing={'line': {'color': 'red', 'width': 1}, 'fillcolor': 'red'},
            )
        )
        self.fig.update(layout_xaxis_rangeslider_visible=False)

    def add_line(self, values: np.ndarray, count: int, **kwargs) -> None:
        self.fig.add_trace(go.Scatter(y=values[-count:], **kwargs))

    def add_area(
        self,
        higher_line: np.ndarray,
        lower_line: np.ndarray,
        count: int,
        fillcolor: str,
        **kwargs,
    ) -> None:
        self.fig.add_trace(
            go.Scatter(
                y=higher_line[-count:],
                **kwargs,
            )
        )
        self.fig.add_trace(
            go.Scatter(
                y=lower_line[-count:],
                fill="tonexty",
                fillcolor=fillcolor,
                **kwargs,
            )
        )

    def to_array(self, format: COLOR_FORMAT = "grayscale") -> np.ndarray:
        buf = io.BytesIO()
        self.fig.write_image(
            buf,
            engine="kaleido",
            format="png",
        )
        # Normalize to 4 channels so the alpha compositing below always works.
        img = Image.open(buf).convert("RGBA")
        rgba = np.asarray(img)
        if format == "rgba":
            return rgba
        # Composite over a white background before dropping the alpha channel.
        rgb = rgba[..., :3]
        alpha = rgba[..., 3]
        background_color = np.array([255, 255, 255])  # white background
        alpha = alpha[:, :, np.newaxis] / 255.0
        rgb_res = rgb * alpha + background_color * (1 - alpha)
        rgb_res = np.clip(rgb_res, 0, 255).astype(np.uint8)
        if format == "rgb":
            return rgb_res

        red, green, blue = (
            rgb_res[:, :, 0],
            rgb_res[:, :, 1],
            rgb_res[:, :, 2],
        )

        # Standard ITU-R BT.601 luma weights for converting to grayscale.
        # Note: the result is float64 in [0, 255]; cast to uint8 if needed.
        gray_values = 0.2989 * red + 0.5870 * green + 0.1140 * blue
        return gray_values

    def show(self) -> None:
        buf = io.BytesIO()
        self.fig.write_image(buf)
        img = Image.open(buf)
        img.show()

    def disable_x_axis(self) -> None:
        self.fig.update_xaxes(showticklabels=False)

    def enable_x_axis(self) -> None:
        self.fig.update_xaxes(showticklabels=True)

    def disable_y_axis(self) -> None:
        self.fig.update_yaxes(showticklabels=False)

    def enable_y_axis(self) -> None:
        self.fig.update_yaxes(showticklabels=True)
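
For context, a minimal usage sketch of the new Canvas helper. The synthetic candle array below is illustrative only; it follows Jesse's [timestamp, open, close, high, low, volume] column order, and image export requires the kaleido package:

import numpy as np

from jesse.libs.plotly_utils.canvas import Canvas

# Synthetic random-walk candles; columns: timestamp, open, close, high, low, volume
n = 64
close = 100 + np.cumsum(np.random.randn(n))
open_ = np.concatenate(([100.0], close[:-1]))
high = np.maximum(open_, close) + 0.5
low = np.minimum(open_, close) - 0.5
candles = np.column_stack([np.arange(n), open_, close, high, low, np.ones(n)])

canvas = Canvas(width=128, height=128)
canvas.add_candles(candles, count=32)      # render only the last 32 candles
obs = canvas.to_array(format="grayscale")  # 128x128 array, e.g. as an RL observation
print(obs.shape)                           # (128, 128)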
14 changes: 11 additions & 3 deletions jesse/modes/backtest_mode.py
@@ -2,6 +2,8 @@
 from typing import Dict, List, Tuple
 import numpy as np
 import pandas as pd
+from agilerl.algorithms.ppo import PPO
+
 import jesse.helpers as jh
 import jesse.services.metrics as stats
 import jesse.services.selectors as selectors
@@ -215,6 +217,7 @@ def _step_simulator(
     generate_equity_curve: bool = False,
     generate_hyperparameters: bool = False,
     generate_logs: bool = False,
+    agent: PPO | None = None,
 ) -> dict:
     # In case generating logs is specifically demanded, the debug mode must be enabled.
     if generate_logs:
@@ -227,7 +230,7 @@

     length = simulation_minutes_length(candles)
     prepare_times_before_simulation(candles)
-    prepare_routes(hyperparameters)
+    prepare_routes(hyperparameters, agent)

     # add initial balance
     save_daily_portfolio_balance()
@@ -342,7 +345,7 @@ def prepare_times_before_simulation(candles: dict) -> None:
     store.app.time = first_candles_set[0][0]


-def prepare_routes(hyperparameters: dict = None) -> None:
+def prepare_routes(hyperparameters: dict = None, agent: PPO | None = None) -> None:
     # initiate strategies
     for r in router.routes:
         # if the r.strategy is str read it from file
@@ -382,6 +385,10 @@ def prepare_routes(hyperparameters: dict = None) -> None:
         # it also injects hyperparameters into self.hp in case the route does not uses any DNAs
         r.strategy._init_objects()

+        # override the agent in case one was passed into the simulation
+        if agent is not None:
+            r.strategy.agent = agent
+
         selectors.get_position(r.exchange, r.symbol).strategy = r.strategy


@@ -552,6 +559,7 @@ def _skip_simulator(
     generate_equity_curve: bool = False,
     generate_hyperparameters: bool = False,
     generate_logs: bool = False,
+    agent: PPO | None = None,
 ) -> dict:
     # In case generating logs is specifically demanded, the debug mode must be enabled.
     if generate_logs:
@@ -561,7 +569,7 @@

     length = simulation_minutes_length(candles)
     prepare_times_before_simulation(candles)
-    prepare_routes(hyperparameters)
+    prepare_routes(hyperparameters, agent)

     # add initial balance
     save_daily_portfolio_balance()
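
To make the prepare_routes() injection concrete, here is a hypothetical strategy sketch. The observation() helper, the action mapping, and the AgileRL get_action() call are illustrative assumptions, not part of this PR; verify the exact PPO method signature against the installed AgileRL version:

import numpy as np

from jesse.strategies import Strategy


class RLStrategy(Strategy):
    # prepare_routes() overrides this attribute when an `agent` is passed
    # into the simulation; otherwise the strategy stays inert.
    agent = None

    def observation(self) -> np.ndarray:
        # Hypothetical feature vector; a real strategy would engineer its own.
        return np.array([self.open, self.high, self.low, self.close])

    def _action(self) -> int:
        # Assumption: AgileRL's PPO returns (action, log_prob, entropy, value)
        # from get_action(); the exact signature depends on the version.
        action, *_ = self.agent.get_action(self.observation())
        return int(np.asarray(action).ravel()[0])

    def should_long(self) -> bool:
        return self.agent is not None and self._action() == 1

    def should_short(self) -> bool:
        return False

    def should_cancel_entry(self) -> bool:
        return True

    def go_long(self):
        self.buy = 1, self.close

    def go_short(self):
        pass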
71 changes: 40 additions & 31 deletions jesse/research/backtest.py
@@ -1,23 +1,26 @@
 from typing import List, Dict
 import copy

+from agilerl.algorithms.ppo import PPO
+

 def backtest(
-        config: dict,
-        routes: List[Dict[str, str]],
-        extra_routes: List[Dict[str, str]],
-        candles: dict,
-        warmup_candles: dict = None,
-        generate_charts: bool = False,
-        generate_tradingview: bool = False,
-        generate_quantstats: bool = False,
-        generate_hyperparameters: bool = False,
-        generate_equity_curve: bool = False,
-        generate_csv: bool = False,
-        generate_json: bool = False,
-        generate_logs: bool = False,
-        hyperparameters: dict = None,
-        fast_mode: bool = False
+    config: dict,
+    routes: List[Dict[str, str]],
+    extra_routes: List[Dict[str, str]],
+    candles: dict,
+    warmup_candles: dict = None,
+    generate_charts: bool = False,
+    generate_tradingview: bool = False,
+    generate_quantstats: bool = False,
+    generate_hyperparameters: bool = False,
+    generate_equity_curve: bool = False,
+    generate_csv: bool = False,
+    generate_json: bool = False,
+    generate_logs: bool = False,
+    hyperparameters: dict = None,
+    fast_mode: bool = False,
+    agent: PPO | None = None,
 ) -> dict:
     """
     An isolated backtest() function which is perfect for using in research, and AI training
@@ -67,26 +70,28 @@ def backtest(
         generate_hyperparameters=generate_hyperparameters,
         generate_logs=generate_logs,
         fast_mode=fast_mode,
+        agent=agent,
     )


 def _isolated_backtest(
-        config: dict,
-        routes: List[Dict[str, str]],
-        extra_routes: List[Dict[str, str]],
-        candles: dict,
-        warmup_candles: dict = None,
-        run_silently: bool = True,
-        hyperparameters: dict = None,
-        generate_charts: bool = False,
-        generate_tradingview: bool = False,
-        generate_quantstats: bool = False,
-        generate_csv: bool = False,
-        generate_json: bool = False,
-        generate_equity_curve: bool = False,
-        generate_hyperparameters: bool = False,
-        generate_logs: bool = False,
-        fast_mode: bool = False,
+    config: dict,
+    routes: List[Dict[str, str]],
+    extra_routes: List[Dict[str, str]],
+    candles: dict,
+    warmup_candles: dict = None,
+    run_silently: bool = True,
+    hyperparameters: dict = None,
+    generate_charts: bool = False,
+    generate_tradingview: bool = False,
+    generate_quantstats: bool = False,
+    generate_csv: bool = False,
+    generate_json: bool = False,
+    generate_equity_curve: bool = False,
+    generate_hyperparameters: bool = False,
+    generate_logs: bool = False,
+    fast_mode: bool = False,
+    agent: PPO | None = None,
 ) -> dict:
     from jesse.services.validators import validate_routes
     from jesse.modes.backtest_mode import simulator
@@ -153,6 +158,7 @@ def _isolated_backtest(
         generate_hyperparameters=generate_hyperparameters,
         generate_logs=generate_logs,
         fast_mode=fast_mode,
+        agent=agent,
     )

     result = {
@@ -181,6 +187,9 @@ def _isolated_backtest(
         result['hyperparameters'] = backtest_result['hyperparameters']
     if generate_logs:
         result['logs'] = backtest_result['logs']
+    if agent is not None:
+        result['scores'] = store.reinforce_learning.scores()
+        result['experience'] = store.reinforce_learning.experience()

     # reset store and config so rerunning would be flawlessly possible
     reset_config()
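
Finally, a sketch of how the new `agent` parameter could be exercised from research code, reusing the hypothetical RLStrategy above. The config values, route, date range, and checkpoint path are placeholders, and the AgileRL loading call is an assumption to verify against the installed version:

import jesse.helpers as jh
from jesse import research
from agilerl.algorithms.ppo import PPO

from my_strategies import RLStrategy  # the sketch above; import path is hypothetical

exchange, symbol, timeframe = 'Binance Perpetual Futures', 'BTC-USDT', '4h'

config = {
    'starting_balance': 10_000,
    'fee': 0.001,
    'type': 'futures',
    'futures_leverage': 1,
    'futures_leverage_mode': 'cross',
    'exchange': exchange,
    'warm_up_candles': 210,
}
routes = [{'exchange': exchange, 'symbol': symbol,
           'timeframe': timeframe, 'strategy': RLStrategy}]

candles = {
    jh.key(exchange, symbol): {
        'exchange': exchange,
        'symbol': symbol,
        'candles': research.get_candles(exchange, symbol, '1m',
                                        jh.date_to_timestamp('2023-01-01'),
                                        jh.date_to_timestamp('2023-06-01')),
    },
}

# Assumption: a previously trained and saved agent; AgileRL exposes checkpoint
# loading, but the exact call (load vs. load_checkpoint) is version-dependent.
agent = PPO.load('storage/agents/best_agent.pt')

result = research.backtest(config, routes, [], candles, agent=agent)
print(result['metrics'])
print(result['scores'])           # RL scores collected by the store
print(len(result['experience']))  # experience gathered during the run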