-
Notifications
You must be signed in to change notification settings - Fork 3.1k
RL backtest with simulator #1299
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
08f725c
1907372
22cb8ee
321691d
35a19aa
773188b
e8b4165
2a1d838
c82d05e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,24 +2,26 @@ | |
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import argparse | ||
import copy | ||
import pickle | ||
import sys | ||
from collections import defaultdict | ||
from pathlib import Path | ||
from typing import Optional, Tuple, Union | ||
from typing import List, Literal, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import torch | ||
from joblib import Parallel, delayed | ||
|
||
from qlib.backtest import collect_data_loop, get_strategy_executor | ||
from qlib.backtest.decision import TradeRangeByTime | ||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime | ||
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor | ||
from qlib.backtest.high_performance_ds import BaseOrderIndicator | ||
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile | ||
from qlib.rl.contrib.utils import read_order_file | ||
from qlib.rl.data.integration import init_qlib | ||
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution | ||
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper | ||
|
||
|
||
|
@@ -41,7 +43,7 @@ def _get_multi_level_executor_config( | |
} | ||
|
||
freqs = list(strategy_config.keys()) | ||
freqs.sort(key=lambda x: pd.Timedelta(x)) | ||
freqs.sort(key=pd.Timedelta) | ||
for freq in freqs: | ||
executor_config = { | ||
"class": "NestedExecutor", | ||
|
@@ -73,7 +75,7 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: | |
# HACK: for qlib v0.8 | ||
value_dict = value_dict.to_series() | ||
try: | ||
value_dict = {k: v for k, v in value_dict.items()} | ||
value_dict = copy.deepcopy(value_dict) | ||
if value_dict["ffr"].empty: | ||
continue | ||
except Exception: | ||
|
@@ -90,32 +92,177 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: | |
return records | ||
|
||
|
||
def _generate_report(decisions: list, report_dict: dict) -> dict: | ||
# TODO: there should be richer annotation for the input (e.g. report) and the returned report | ||
# TODO: For example, @ dataclass with typed fields and detailed docstrings. | ||
def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict: | ||
"""Generate backtest reports | ||
|
||
Parameters | ||
---------- | ||
decisions: | ||
List of trade decisions. | ||
report_indicators | ||
List of indicator reports. | ||
Returns | ||
------- | ||
|
||
""" | ||
indicator_dict = defaultdict(list) | ||
indicator_his = defaultdict(list) | ||
for report_indicator in report_indicators: | ||
for key, value in report_indicator.items(): | ||
if key.endswith("_obj"): | ||
indicator_his[key].append(value.order_indicator_his) | ||
else: | ||
indicator_dict[key].append(value) | ||
|
||
report = {} | ||
decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")]) | ||
for key in ["1minute", "5minute", "30minute", "1day"]: | ||
if key not in report_dict["indicator"]: | ||
decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")]) | ||
for key in ["1min", "5min", "30min", "1day"]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I hard-coded this to quickly run through the experiments. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is part of the following issue mentioned by you-n-g before. I will redesign the entire logic in later PRs.
|
||
if key not in indicator_dict: | ||
continue | ||
report[key] = report_dict["indicator"][key] | ||
report[key + "_obj"] = _convert_indicator_to_dataframe( | ||
report_dict["indicator"][key + "_obj"].order_indicator_his | ||
) | ||
cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"]) | ||
|
||
report[key] = pd.concat(indicator_dict[key]) | ||
report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + "_obj"]]) | ||
|
||
cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"]) | ||
if len(cur_details) > 0: | ||
cur_details.pop("freq") | ||
report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer") | ||
if "1minute" in report_dict["report"]: | ||
report["simulator"] = report_dict["report"]["1minute"][0] | ||
|
||
return report | ||
|
||
|
||
def single( | ||
def single_with_simulator( | ||
backtest_config: dict, | ||
orders: pd.DataFrame, | ||
split: str = "stock", | ||
split: Literal["stock", "day"] = "stock", | ||
cash_limit: float = None, | ||
generate_report: bool = False, | ||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should be some docstring for this function. |
||
"""Run backtest in a single thread with SingleAssetOrderExecution simulator. The orders will be executed day by day. | ||
A new simulator will be created and used for every single-day order. | ||
|
||
Parameters | ||
---------- | ||
backtest_config: | ||
Backtest config | ||
orders: | ||
Orders to be executed. Example format: | ||
datetime instrument amount direction | ||
0 2020-06-01 INST 600.0 0 | ||
1 2020-06-02 INST 700.0 1 | ||
... | ||
split | ||
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date. | ||
cash_limit | ||
Limitation of cash. | ||
generate_report | ||
Whether to generate reports. | ||
|
||
Returns | ||
------- | ||
If generate_report is True, return execution records and the generated report. Otherwise, return only records. | ||
""" | ||
if split == "stock": | ||
stock_id = orders.iloc[0].instrument | ||
init_qlib(backtest_config["qlib"], part=stock_id) | ||
else: | ||
day = orders.iloc[0].datetime | ||
init_qlib(backtest_config["qlib"], part=day) | ||
|
||
stocks = orders.instrument.unique().tolist() | ||
|
||
reports = [] | ||
decisions = [] | ||
for _, row in orders.iterrows(): | ||
date = pd.Timestamp(row["datetime"]) | ||
start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day) | ||
end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day) | ||
order = Order( | ||
stock_id=row["instrument"], | ||
amount=row["amount"], | ||
direction=OrderDir(row["direction"]), | ||
start_time=start_time, | ||
end_time=end_time, | ||
) | ||
|
||
executor_config = _get_multi_level_executor_config( | ||
strategy_config=backtest_config["strategies"], | ||
cash_limit=cash_limit, | ||
generate_report=generate_report, | ||
) | ||
|
||
exchange_config = copy.deepcopy(backtest_config["exchange"]) | ||
exchange_config.update( | ||
{ | ||
"codes": stocks, | ||
"freq": "1min", | ||
} | ||
) | ||
|
||
simulator = SingleAssetOrderExecution( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does this simulator generate reports? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The simulator will execute one hidden step when it is created. When it is used in training, it will pause at the first |
||
order=order, | ||
executor_config=executor_config, | ||
exchange_config=exchange_config, | ||
qlib_config=None, | ||
cash_limit=None, | ||
backtest_mode=True, | ||
) | ||
|
||
reports.append(simulator.report_dict) | ||
decisions += simulator.decisions | ||
|
||
indicator = {k: v for report in reports for k, v in report["indicator"]["1day_obj"].order_indicator_his.items()} | ||
records = _convert_indicator_to_dataframe(indicator) | ||
assert records is None or not np.isnan(records["ffr"]).any() | ||
|
||
if generate_report: | ||
report = _generate_report(decisions, [report["indicator"] for report in reports]) | ||
|
||
if split == "stock": | ||
stock_id = orders.iloc[0].instrument | ||
report = {stock_id: report} | ||
else: | ||
day = orders.iloc[0].datetime | ||
report = {day: report} | ||
|
||
return records, report | ||
else: | ||
return records | ||
|
||
|
||
def single_with_collect_data_loop( | ||
backtest_config: dict, | ||
orders: pd.DataFrame, | ||
split: Literal["stock", "day"] = "stock", | ||
cash_limit: float = None, | ||
generate_report: bool = False, | ||
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]: | ||
"""Run backtest in a single thread with collect_data_loop. | ||
|
||
Parameters | ||
---------- | ||
backtest_config: | ||
Backtest config | ||
orders: | ||
Orders to be executed. Example format: | ||
datetime instrument amount direction | ||
0 2020-06-01 INST 600.0 0 | ||
1 2020-06-02 INST 700.0 1 | ||
... | ||
split | ||
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date. | ||
cash_limit | ||
Limitation of cash. | ||
generate_report | ||
Whether to generate reports. | ||
|
||
Returns | ||
------- | ||
If generate_report is True, return execution records and the generated report. Otherwise, return only records. | ||
""" | ||
|
||
if split == "stock": | ||
stock_id = orders.iloc[0].instrument | ||
init_qlib(backtest_config["qlib"], part=stock_id) | ||
|
@@ -127,7 +274,7 @@ def single( | |
trade_end_time = orders["datetime"].max() | ||
stocks = orders.instrument.unique().tolist() | ||
|
||
top_strategy_config = { | ||
strategy_config = { | ||
"class": "FileOrderStrategy", | ||
"module_path": "qlib.contrib.strategy.rule_strategy", | ||
"kwargs": { | ||
|
@@ -139,14 +286,14 @@ def single( | |
}, | ||
} | ||
|
||
top_executor_config = _get_multi_level_executor_config( | ||
executor_config = _get_multi_level_executor_config( | ||
strategy_config=backtest_config["strategies"], | ||
cash_limit=cash_limit, | ||
generate_report=generate_report, | ||
) | ||
|
||
tmp_backtest_config = copy.deepcopy(backtest_config["exchange"]) | ||
tmp_backtest_config.update( | ||
exchange_config = copy.deepcopy(backtest_config["exchange"]) | ||
exchange_config.update( | ||
{ | ||
"codes": stocks, | ||
"freq": "1min", | ||
|
@@ -156,11 +303,11 @@ def single( | |
strategy, executor = get_strategy_executor( | ||
start_time=pd.Timestamp(trade_start_time), | ||
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1), | ||
strategy=top_strategy_config, | ||
executor=top_executor_config, | ||
strategy=strategy_config, | ||
executor=executor_config, | ||
benchmark=None, | ||
account=cash_limit if cash_limit is not None else int(1e12), | ||
exchange_kwargs=tmp_backtest_config, | ||
exchange_kwargs=exchange_config, | ||
pos_type="Position" if cash_limit is not None else "InfPosition", | ||
) | ||
_set_env_for_all_strategy(executor=executor) | ||
|
@@ -172,7 +319,7 @@ def single( | |
assert records is None or not np.isnan(records["ffr"]).any() | ||
|
||
if generate_report: | ||
report = _generate_report(decisions, report_dict) | ||
report = _generate_report(decisions, [report_dict["indicator"]]) | ||
if split == "stock": | ||
stock_id = orders.iloc[0].instrument | ||
report = {stock_id: report} | ||
|
@@ -184,7 +331,7 @@ def single( | |
return records | ||
|
||
|
||
def backtest(backtest_config: dict) -> pd.DataFrame: | ||
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame: | ||
order_df = read_order_file(backtest_config["order_file"]) | ||
|
||
cash_limit = backtest_config["exchange"].pop("cash_limit") | ||
|
@@ -193,6 +340,7 @@ def backtest(backtest_config: dict) -> pd.DataFrame: | |
stock_pool = order_df["instrument"].unique().tolist() | ||
stock_pool.sort() | ||
|
||
single = single_with_simulator if with_simulator else single_with_collect_data_loop | ||
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} | ||
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 | ||
res = Parallel(**mp_config)( | ||
|
@@ -227,5 +375,12 @@ def backtest(backtest_config: dict) -> pd.DataFrame: | |
warnings.filterwarnings("ignore", category=DeprecationWarning) | ||
warnings.filterwarnings("ignore", category=RuntimeWarning) | ||
|
||
path = sys.argv[1] | ||
backtest(get_backtest_config_fromfile(path)) | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") | ||
parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend") | ||
args = parser.parse_args() | ||
|
||
backtest( | ||
backtest_config=get_backtest_config_fromfile(args.config_path), | ||
with_simulator=args.use_simulator, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add some explanations on why (in what scenarios) we need this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.