Skip to content

Compare_agent.py #6

@vicks4u

Description

@vicks4u

"""
RL + MetaTrader5 trading bot template.

  • Train with historical data (PPO from stable-baselines3)
  • Optionally execute trades via MetaTrader5 (set LIVE=True to enable)

CAVEAT: This is an educational template. Backtest & paper-trade first.
"""

import time
import numpy as np
import pandas as pd
import gym
from gym import spaces
import MetaTrader5 as mt5
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import CheckpointCallback

# -------------------------
# USER CONFIG
# -------------------------

# Instrument that is fetched for training and traded in the live loop.
SYMBOL = "EURUSD"
# Bar size: 5-minute bars.
TIMEFRAME = mt5.TIMEFRAME_M5
# Observation window length, in bars.
LOOKBACK = 50
# Offset passed to the historical fetch as its start position.
START_POS = 0
# Volume (lots) used for every order.
LOT_SIZE = 0.01
# Real orders are sent only when this is True.
# Set to True only after full testing (demo account first!).
LIVE = False
# File path (without extension) the trained PPO model is saved to / loaded from.
MODEL_PATH = "ppo_mt5_model"
# Total environment steps used for training; adjust as you like.
TRAIN_TIMESTEPS = 20000

# -------------------------

# -------------------------
# Helper: connect to MT5
# -------------------------

def mt5_connect():
    """
    Initialize the MetaTrader5 terminal connection and select SYMBOL.

    Returns:
        True once the terminal is initialized and SYMBOL is selected.

    Raises:
        RuntimeError: if initialization fails, terminal info is unavailable,
            or SYMBOL cannot be selected. When a failure happens after
            initialize() succeeded, the connection is shut down before the
            error is raised so it is not left open.
    """
    if not mt5.initialize():
        raise RuntimeError(f"MT5 initialize() failed, error={mt5.last_error()}")
    try:
        info = mt5.terminal_info()
        if info is None:
            raise RuntimeError("Failed to get terminal info after initialize()")
        print("MT5 terminal initialized:", info.product)
        # Ensure symbol is available
        if not mt5.symbol_select(SYMBOL, True):
            raise RuntimeError(f"Failed to select symbol {SYMBOL}")
    except Exception:
        # initialize() succeeded, so release the connection before propagating.
        mt5.shutdown()
        raise
    return True

def mt5_shutdown():
    """Shut down the MetaTrader5 connection (thin wrapper over mt5.shutdown)."""
    mt5.shutdown()

# -------------------------
# Get historical OHLCV
# -------------------------

def fetch_bars(symbol, timeframe, n_bars):
    """
    Fetch `n_bars` bars for `symbol` at `timeframe`, starting at START_POS.

    Args:
        symbol: instrument name passed to MetaTrader5.
        timeframe: MetaTrader5 timeframe constant.
        n_bars: number of bars to request.

    Returns:
        pandas.DataFrame built from the returned rates, with the 'time'
        column converted from epoch seconds to datetimes.

    Raises:
        RuntimeError: if MetaTrader5 returns None or an empty result.
    """
    # copy_rates_from_pos returns numpy array with fields:
    # time, open, high, low, close, tick_volume, ...
    rates = mt5.copy_rates_from_pos(symbol, timeframe, START_POS, n_bars)
    # An empty array would otherwise surface later as an opaque IndexError
    # in callers that read the last row, so treat it as a failed fetch too.
    if rates is None or len(rates) == 0:
        raise RuntimeError(f"Failed to fetch rates for {symbol}: {mt5.last_error()}")
    df = pd.DataFrame(rates)
    df['time'] = pd.to_datetime(df['time'], unit='s')
    return df

# -------------------------
# Simple trading Gym env
# -------------------------

class MT5TradingEnv(gym.Env):
    """
    Simple single-symbol trading environment over a DataFrame of bars.

    Observation: the last `lookback` closes, each divided by the most recent
        close in the window minus 1, followed by the current position
        (0/1/-1). Shape (lookback + 1,), dtype float32.
    Actions: 0=hold, 1=buy (open long, or flip short -> long),
        2=sell (open short, or flip long -> short).
    Reward: realized price difference when a position is flipped, plus
        0.1 * unrealized price difference of the open position at the next bar.
    NOTE: Simplified; this is a research template, not production-ready.
    """

    def __init__(self, df: pd.DataFrame, lookback=LOOKBACK):
        """
        Args:
            df: bar data; must contain a "close" column and at least
                `lookback` rows.
            lookback: number of past closes included in each observation.

        Raises:
            ValueError: if `df` has fewer than `lookback` rows, since no
                full observation window could be built.
        """
        super().__init__()
        if len(df) < lookback:
            raise ValueError(
                f"Need at least {lookback} bars to build an observation, got {len(df)}"
            )
        # Positional 0..n-1 index so label-based .loc lookups match self.ptr.
        self.df = df.reset_index(drop=True)
        self.lookback = lookback
        self.ptr = lookback  # current index in df
        self.position = 0  # -1 short, 0 flat, 1 long
        self.entry_price = 0.0
        # Observations: lookback closes (normalized) + position
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(lookback + 1,), dtype=np.float32
        )
        # Actions: hold(0), buy(1), sell(2)
        self.action_space = spaces.Discrete(3)

    def _get_obs(self):
        """Return the float32 observation for the window ending just before self.ptr."""
        # .loc slicing is inclusive on both ends, hence the "- 1".
        closes = self.df.loc[
            self.ptr - self.lookback:self.ptr - 1, "close"
        ].values.astype(np.float32)
        # normalize closes by dividing by last close
        norm = closes / (closes[-1] + 1e-9) - 1.0
        # Build the position as float32 so the result is not promoted to
        # float64 and stays consistent with observation_space.dtype.
        position = np.array([self.position], dtype=np.float32)
        obs = np.concatenate([norm, position], axis=0)
        return obs.astype(np.float32, copy=False)

    def reset(self):
        """Rewind to the first full window, go flat, and return the initial observation."""
        self.ptr = self.lookback
        self.position = 0
        self.entry_price = 0.0
        return self._get_obs()

    def step(self, action):
        """
        Apply `action` at the close of the current bar and advance one bar.

        Returns:
            (obs, reward, done, info): `obs` is all zeros once the data is
            exhausted; `info` carries the new pointer under "ptr".
        """
        done = False
        reward = 0.0
        price = float(self.df.loc[self.ptr, "close"])
        # Action logic
        if action == 1:  # buy
            if self.position == 0:
                self.position = 1
                self.entry_price = price
            elif self.position == -1:
                # close short and go long
                reward += (self.entry_price - price)  # profit from short
                self.position = 1
                self.entry_price = price
        elif action == 2:  # sell
            if self.position == 0:
                self.position = -1
                self.entry_price = price
            elif self.position == 1:
                # close long and go short
                reward += (price - self.entry_price)  # profit from long
                self.position = -1
                self.entry_price = price
        # Move pointer
        self.ptr += 1
        if self.ptr >= len(self.df):
            done = True
        else:
            # reward can also be shaped by unrealized pnl:
            next_price = float(self.df.loc[self.ptr, "close"])
            unrealized = 0.0
            if self.position == 1:
                unrealized = next_price - self.entry_price
            elif self.position == -1:
                unrealized = self.entry_price - next_price
            # small per-step reward = unrealized PnL scaled
            reward += unrealized * 0.1

        if done:
            obs = np.zeros(self.observation_space.shape, dtype=np.float32)
        else:
            obs = self._get_obs()
        info = {"ptr": self.ptr}
        return obs, float(reward), done, info

# -------------------------
# Order helpers
# -------------------------

def send_order(symbol, action, lot=LOT_SIZE, deviation=20):
    """
    Send a market order (ORDER_TYPE_BUY / ORDER_TYPE_SELL) for `symbol`.

    Basic error checking included. For production you need more robust code.

    Args:
        symbol: instrument to trade.
        action: 1=buy (priced at the ask), 2=sell (priced at the bid).
        lot: order volume.
        deviation: value placed in the request's "deviation" field.

    Returns:
        Whatever mt5.order_send returns for the request; the caller is
        responsible for inspecting it.

    Raises:
        ValueError: if `action` is neither 1 nor 2, so that an unexpected
            value is never silently treated as a sell.
        RuntimeError: if no tick is available for `symbol`.
    """
    if action not in (1, 2):
        raise ValueError(f"Unsupported order action {action!r}; expected 1 (buy) or 2 (sell)")
    # Fetch the tick once so ask/bid come from the same snapshot, and fail
    # clearly if it is unavailable instead of raising AttributeError on None.
    tick = mt5.symbol_info_tick(symbol)
    if tick is None:
        raise RuntimeError(f"No tick data for {symbol}: {mt5.last_error()}")
    is_buy = action == 1
    price = tick.ask if is_buy else tick.bid
    request = {
        "action": mt5.TRADE_ACTION_DEAL,
        "symbol": symbol,
        "volume": float(lot),
        "type": mt5.ORDER_TYPE_BUY if is_buy else mt5.ORDER_TYPE_SELL,
        "price": float(price),
        "deviation": deviation,
        "magic": 234000,
        "comment": "RL-bot",
        "type_filling": mt5.ORDER_FILLING_IOC,
    }
    return mt5.order_send(request)

# -------------------------
# Main: training flow
# -------------------------

def train_agent():
    """
    Fetch historical bars from MT5, train a PPO agent on them, and save it.

    The MT5 connection is needed only for the data fetch, so it is released
    right after (also when the fetch fails) instead of being held for the
    whole training run. Checkpoints are written under ./logs/ and the final
    model is saved to MODEL_PATH.
    """
    mt5_connect()
    try:
        # fetch historical bars
        n_bars = 5000
        df = fetch_bars(SYMBOL, TIMEFRAME, n_bars)
    finally:
        mt5_shutdown()
    print(f"Fetched {len(df)} bars for {SYMBOL}")
    # Create env
    env = DummyVecEnv([lambda: MT5TradingEnv(df, lookback=LOOKBACK)])
    # model
    model = PPO("MlpPolicy", env, verbose=1)
    # save checkpoints
    cb = CheckpointCallback(save_freq=5000, save_path="./logs/", name_prefix="ppo_mt5")
    model.learn(total_timesteps=TRAIN_TIMESTEPS, callback=cb)
    model.save(MODEL_PATH)
    print("Training complete, model saved to", MODEL_PATH)

# -------------------------
# Real-time execution loop (paper/live)
# -------------------------

def run_live_loop(model_path=MODEL_PATH, poll_seconds=5):
    """
    Poll MT5 for new bars and act on the model's decision for each one.

    Args:
        model_path: path of the saved PPO model to load.
        poll_seconds: delay between polls for a new bar.

    When a bar with a newer timestamp appears, the recent history is fetched
    again, the model is queried on the latest LOOKBACK bars, and an order is
    sent if LIVE is True (otherwise the decision is only logged). The loop
    stops on KeyboardInterrupt; other exceptions are printed and the loop
    retries after 5 seconds. The MT5 connection is shut down however the
    function exits.

    NOTE(review): a fresh env is built for every decision, so the position
    component of the observation is always 0 here — confirm this is intended.
    """
    mt5_connect()
    try:
        model = PPO.load(model_path)
        print("Loaded model:", model_path)
        # We'll maintain a small in-memory buffer of recent bars
        n_history = LOOKBACK + 10
        df = fetch_bars(SYMBOL, TIMEFRAME, n_history)
        while True:
            try:
                latest = fetch_bars(SYMBOL, TIMEFRAME, 1)
                if latest['time'].iloc[-1] > df['time'].iloc[-1]:
                    # Refetch the whole buffer rather than appending: the bar
                    # that was still forming at the previous fetch now has
                    # its final close, which a plain append would never update.
                    df = fetch_bars(SYMBOL, TIMEFRAME, n_history)
                    # The env's reset() observes rows 0..lookback-1 of the
                    # frame it is given, so hand it only the newest LOOKBACK
                    # bars; otherwise the decision uses the oldest bars.
                    window = df.iloc[-LOOKBACK:].reset_index(drop=True)
                    # Build an env instance for this single-step decision
                    env = MT5TradingEnv(window, lookback=LOOKBACK)
                    obs = env.reset()
                    action, _states = model.predict(obs, deterministic=True)
                    print(f"[{pd.to_datetime('now')}] Action: {action} | Price: {df['close'].iloc[-1]}")
                    # Send order if LIVE
                    if LIVE:
                        decision = int(action)
                        if decision in (1, 2):
                            res = send_order(SYMBOL, decision)
                            print("Order send result:", res)
                    else:
                        # paper-trade: just log what would happen
                        print("LIVE=False -> paper trade logged only")
                time.sleep(poll_seconds)
            except KeyboardInterrupt:
                print("Stopping live loop (KeyboardInterrupt).")
                break
            except Exception as e:
                print("Exception in live loop:", str(e))
                time.sleep(5)
    finally:
        # Release the terminal connection even if the loop exits abnormally
        # (e.g. Ctrl-C during the error back-off sleep, or a failed model load).
        mt5_shutdown()

# -------------------------
# If run as script
# -------------------------

# Script entry point: `--mode train` (default) trains and saves the model,
# `--mode run` starts the live/paper loop. The dunder names are required;
# a bare `name == "main"` raises NameError.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["train", "run"], default="train")
    args = parser.parse_args()
    if args.mode == "train":
        print("Starting training...")
        train_agent()
    elif args.mode == "run":
        print("Starting live/paper run loop...")
        run_live_loop()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions