Description
"""
Compare PPO, A2C, and DQN on an MT5 trading environment.
- Each agent is trained separately
- Results (mean reward) are logged
"""
import numpy as np
import pandas as pd
import gym
from gym import spaces
import MetaTrader5 as mt5
from stable_baselines3 import PPO, A2C, DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy

# -------------------------
# CONFIG
# -------------------------
SYMBOL = "EURUSD"
TIMEFRAME = mt5.TIMEFRAME_M5
LOOKBACK = 50
TRAIN_TIMESTEPS = 20000
N_BARS = 5000
SEED = 42

# -------------------------
# MT5 Connection
# -------------------------
def mt5_connect():
    if not mt5.initialize():
        raise RuntimeError(f"MT5 init failed: {mt5.last_error()}")
    if not mt5.symbol_select(SYMBOL, True):
        raise RuntimeError(f"Could not select {SYMBOL}")

def mt5_shutdown():
    mt5.shutdown()

def fetch_bars(symbol, timeframe, n_bars):
    rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, n_bars)
    if rates is None:
        raise RuntimeError(f"Failed to fetch data: {mt5.last_error()}")
    df = pd.DataFrame(rates)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    return df
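
# Optional quick look at the data (an illustrative sketch, not part of the original
# script): connect, pull a few bars, and print the tail to confirm the expected
# columns (time/open/high/low/close/tick_volume) before training.
# `preview_bars` is a hypothetical helper name.
def preview_bars(n=10):
    mt5_connect()
    try:
        print(fetch_bars(SYMBOL, TIMEFRAME, n).tail())
    finally:
        mt5_shutdown()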

# -------------------------
# Custom Gym Env
# -------------------------
class MT5TradingEnv(gym.Env):
    def __init__(self, df, lookback=LOOKBACK):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.lookback = lookback
        self.ptr = lookback
        self.position = 0
        self.entry_price = 0
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf,
                                            shape=(lookback + 1,), dtype=np.float32)
        self.action_space = spaces.Discrete(3)  # 0=hold, 1=buy, 2=sell

    def _get_obs(self):
        closes = self.df.loc[self.ptr - self.lookback:self.ptr - 1, "close"].values.astype(np.float32)
        norm = closes / (closes[-1] + 1e-9) - 1.0  # prices relative to the most recent close
        return np.concatenate([norm, [float(self.position)]], axis=0)

    def reset(self):
        self.ptr = self.lookback
        self.position = 0
        self.entry_price = 0
        return self._get_obs()

    def step(self, action):
        done, reward = False, 0.0
        price = float(self.df.loc[self.ptr, "close"])
        if action == 1:  # buy
            if self.position == 0:
                self.position, self.entry_price = 1, price
            elif self.position == -1:  # close short, open long
                reward += (self.entry_price - price)
                self.position, self.entry_price = 1, price
        elif action == 2:  # sell
            if self.position == 0:
                self.position, self.entry_price = -1, price
            elif self.position == 1:  # close long, open short
                reward += (price - self.entry_price)
                self.position, self.entry_price = -1, price
        self.ptr += 1
        if self.ptr >= len(self.df):
            done = True
        else:
            next_price = float(self.df.loc[self.ptr, "close"])
            if self.position == 1:
                reward += (next_price - self.entry_price) * 0.1
            elif self.position == -1:
                reward += (self.entry_price - next_price) * 0.1
        obs = self._get_obs() if not done else np.zeros(self.observation_space.shape, dtype=np.float32)
        return obs, float(reward), done, {}
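
# Optional env sanity check (a minimal sketch, not in the original script):
# stable-baselines3 bundles an environment checker that exercises reset/step and
# validates the observation/action spaces. This assumes an SB3 version that still
# accepts gym-style (4-tuple step) environments; `sanity_check_env` is a
# hypothetical helper name.
def sanity_check_env(df):
    from stable_baselines3.common.env_checker import check_env
    check_env(MT5TradingEnv(df, lookback=LOOKBACK), warn=True)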

# -------------------------
# Training & Evaluation
# -------------------------
def run_comparison():
    mt5_connect()
    df = fetch_bars(SYMBOL, TIMEFRAME, N_BARS)
    mt5_shutdown()

    results = {}
    agents = {
        "PPO": PPO,
        "A2C": A2C,
        "DQN": DQN,
    }

    for name, algo in agents.items():
        print(f"\n=== Training {name} ===")
        env = DummyVecEnv([lambda: MT5TradingEnv(df, lookback=LOOKBACK)])
        model = algo("MlpPolicy", env, verbose=0, seed=SEED)
        model.learn(total_timesteps=TRAIN_TIMESTEPS)
        mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=5)
        results[name] = (mean_reward, std_reward)
        print(f"{name} → mean reward: {mean_reward:.2f}, std: {std_reward:.2f}")

    print("\n=== Summary ===")
    for k, v in results.items():
        print(f"{k}: mean {v[0]:.2f}, std {v[1]:.2f}")

if __name__ == "__main__":
    run_comparison()
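
# Usage sketch (assumptions, not from the original): requires a running, logged-in
# MetaTrader 5 terminal on the same Windows machine, plus roughly
#   pip install MetaTrader5 stable-baselines3 pandas numpy
# Note that stable-baselines3 >= 2.0 expects gymnasium-style envs; the gym-style
# reset/step signatures used above target the 1.x line.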