Skip to content

Commit 216a8ec

Browse files
authored
RL backtest with simulator (#1299)
* RL backtest with simulator * Minor modification in init_qlib * Cherry pick PR 1302 * Resolve PR comments * Fix missing data processing * Minor bugfix * Add TODOs and docs * Add a comment
1 parent 54928e9 commit 216a8ec

File tree

11 files changed

+354
-92
lines changed

11 files changed

+354
-92
lines changed

qlib/backtest/decision.py

+15
Original file line numberDiff line numberDiff line change
@@ -576,3 +576,18 @@ def __repr__(self) -> str:
576576
f"trade_range: {self.trade_range}; "
577577
f"order_list[{len(self.order_list)}]"
578578
)
579+
580+
581+
class TradeDecisionWithDetails(TradeDecisionWO):
    """Trade decision that additionally carries detail information.

    The detail information is used to generate execution reports.

    Parameters
    ----------
    order_list
        Orders of this decision; forwarded unchanged to the parent constructor.
    strategy
        Strategy that made the decision; forwarded unchanged to the parent constructor.
    trade_range
        Optional trade range; forwarded unchanged to the parent constructor.
    details
        Arbitrary detail payload stored as-is on ``self.details``. Defaults to ``None``.
    """

    def __init__(
        self,
        order_list: List[Order],
        strategy: BaseStrategy,
        trade_range: Optional[Tuple[int, int]] = None,
        details: Optional[Any] = None,
    ) -> None:
        # The parent class owns the order/strategy/range handling.
        super().__init__(order_list, strategy, trade_range)
        # The only extra state of this subclass: the payload consumed by report generation.
        self.details = details

qlib/rl/contrib/backtest.py

+184-29
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,26 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5+
import argparse
56
import copy
67
import pickle
7-
import sys
8+
from collections import defaultdict
89
from pathlib import Path
9-
from typing import Optional, Tuple, Union
10+
from typing import List, Literal, Optional, Tuple, Union
1011

1112
import numpy as np
1213
import pandas as pd
1314
import torch
1415
from joblib import Parallel, delayed
1516

1617
from qlib.backtest import collect_data_loop, get_strategy_executor
17-
from qlib.backtest.decision import TradeRangeByTime
18+
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir, TradeRangeByTime
1819
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
1920
from qlib.backtest.high_performance_ds import BaseOrderIndicator
2021
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
2122
from qlib.rl.contrib.utils import read_order_file
2223
from qlib.rl.data.integration import init_qlib
24+
from qlib.rl.order_execution.simulator_qlib import SingleAssetOrderExecution
2325
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
2426

2527

@@ -41,7 +43,7 @@ def _get_multi_level_executor_config(
4143
}
4244

4345
freqs = list(strategy_config.keys())
44-
freqs.sort(key=lambda x: pd.Timedelta(x))
46+
freqs.sort(key=pd.Timedelta)
4547
for freq in freqs:
4648
executor_config = {
4749
"class": "NestedExecutor",
@@ -73,7 +75,7 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
7375
# HACK: for qlib v0.8
7476
value_dict = value_dict.to_series()
7577
try:
76-
value_dict = {k: v for k, v in value_dict.items()}
78+
value_dict = copy.deepcopy(value_dict)
7779
if value_dict["ffr"].empty:
7880
continue
7981
except Exception:
@@ -90,32 +92,177 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
9092
return records
9193

9294

93-
def _generate_report(decisions: list, report_dict: dict) -> dict:
95+
# TODO: there should be richer annotation for the input (e.g. report) and the returned report
96+
# TODO: For example, use a @dataclass with typed fields and detailed docstrings.
97+
def _generate_report(decisions: List[BaseTradeDecision], report_indicators: List[dict]) -> dict:
98+
"""Generate backtest reports
99+
100+
Parameters
101+
----------
102+
decisions:
103+
List of trade decisions.
104+
report_indicators
105+
List of indicator reports.
106+
Returns
107+
-------
108+
109+
"""
110+
indicator_dict = defaultdict(list)
111+
indicator_his = defaultdict(list)
112+
for report_indicator in report_indicators:
113+
for key, value in report_indicator.items():
114+
if key.endswith("_obj"):
115+
indicator_his[key].append(value.order_indicator_his)
116+
else:
117+
indicator_dict[key].append(value)
118+
94119
report = {}
95-
decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")])
96-
for key in ["1minute", "5minute", "30minute", "1day"]:
97-
if key not in report_dict["indicator"]:
120+
decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")])
121+
for key in ["1min", "5min", "30min", "1day"]:
122+
if key not in indicator_dict:
98123
continue
99-
report[key] = report_dict["indicator"][key]
100-
report[key + "_obj"] = _convert_indicator_to_dataframe(
101-
report_dict["indicator"][key + "_obj"].order_indicator_his
102-
)
103-
cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"])
124+
125+
report[key] = pd.concat(indicator_dict[key])
126+
report[key + "_obj"] = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key + "_obj"]])
127+
128+
cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"])
104129
if len(cur_details) > 0:
105130
cur_details.pop("freq")
106131
report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer")
107-
if "1minute" in report_dict["report"]:
108-
report["simulator"] = report_dict["report"]["1minute"][0]
132+
109133
return report
110134

111135

112-
def single(
136+
def single_with_simulator(
    backtest_config: dict,
    orders: pd.DataFrame,
    split: Literal["stock", "day"] = "stock",
    cash_limit: float = None,
    generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
    """Run backtest in a single thread with the SingleAssetOrderExecution simulator.

    Orders are executed one row at a time; a fresh simulator is created for every order, and the
    trading window of each order is the configured start/end time moved onto the order's date.

    Parameters
    ----------
    backtest_config:
        Backtest config. The keys read here are "qlib", "start_time", "end_time", "strategies"
        and "exchange".
    orders:
        Orders to be executed. Example format:
            datetime instrument amount direction
            0 2020-06-01 INST 600.0 0
            1 2020-06-02 INST 700.0 1
            ...
    split
        Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
        The instrument (or datetime) of the first order is used as the partition key.
    cash_limit
        Limitation of cash. It is forwarded to the executor config.
    generate_report
        Whether to generate reports.

    Returns
    -------
    If generate_report is True, return execution records and the generated report (wrapped in a
    dict keyed by the partition key). Otherwise, return only records.
    """
    # The partition key selects the data part for qlib and, later, labels the report.
    first_order = orders.iloc[0]
    part_key = first_order.instrument if split == "stock" else first_order.datetime
    init_qlib(backtest_config["qlib"], part=part_key)

    stocks = orders.instrument.unique().tolist()

    reports = []
    decisions = []
    for _, row in orders.iterrows():
        # Move the configured intraday window onto the calendar date of this order.
        trade_date = pd.Timestamp(row["datetime"])
        date_fields = {"year": trade_date.year, "month": trade_date.month, "day": trade_date.day}
        start_time = pd.Timestamp(backtest_config["start_time"]).replace(**date_fields)
        end_time = pd.Timestamp(backtest_config["end_time"]).replace(**date_fields)

        order = Order(
            stock_id=row["instrument"],
            amount=row["amount"],
            direction=OrderDir(row["direction"]),
            start_time=start_time,
            end_time=end_time,
        )

        executor_config = _get_multi_level_executor_config(
            strategy_config=backtest_config["strategies"],
            cash_limit=cash_limit,
            generate_report=generate_report,
        )

        # Work on a copy so the caller's exchange config is never modified.
        exchange_config = copy.deepcopy(backtest_config["exchange"])
        exchange_config["codes"] = stocks
        exchange_config["freq"] = "1min"

        # NOTE(review): `cash_limit` reaches the executor config above, but the simulator itself is
        # built with `cash_limit=None` — confirm that this is intentional.
        simulator = SingleAssetOrderExecution(
            order=order,
            executor_config=executor_config,
            exchange_config=exchange_config,
            qlib_config=None,
            cash_limit=None,
            backtest_mode=True,
        )

        reports.append(simulator.report_dict)
        decisions.extend(simulator.decisions)

    # Merge the daily order-indicator histories of all simulators; later entries win on key clashes.
    indicator = {}
    for sim_report in reports:
        indicator.update(sim_report["indicator"]["1day_obj"].order_indicator_his.items())

    records = _convert_indicator_to_dataframe(indicator)
    assert records is None or not np.isnan(records["ffr"]).any()

    if not generate_report:
        return records

    report = _generate_report(decisions, [sim_report["indicator"] for sim_report in reports])
    return records, {part_key: report}
233+
234+
235+
def single_with_collect_data_loop(
236+
backtest_config: dict,
237+
orders: pd.DataFrame,
238+
split: Literal["stock", "day"] = "stock",
239+
cash_limit: float = None,
240+
generate_report: bool = False,
241+
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
242+
"""Run backtest in a single thread with collect_data_loop.
243+
244+
Parameters
245+
----------
246+
backtest_config:
247+
Backtest config
248+
orders:
249+
Orders to be executed. Example format:
250+
datetime instrument amount direction
251+
0 2020-06-01 INST 600.0 0
252+
1 2020-06-02 INST 700.0 1
253+
...
254+
split
255+
Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
256+
cash_limit
257+
Limitation of cash.
258+
generate_report
259+
Whether to generate reports.
260+
261+
Returns
262+
-------
263+
If generate_report is True, return execution records and the generated report. Otherwise, return only records.
264+
"""
265+
119266
if split == "stock":
120267
stock_id = orders.iloc[0].instrument
121268
init_qlib(backtest_config["qlib"], part=stock_id)
@@ -127,7 +274,7 @@ def single(
127274
trade_end_time = orders["datetime"].max()
128275
stocks = orders.instrument.unique().tolist()
129276

130-
top_strategy_config = {
277+
strategy_config = {
131278
"class": "FileOrderStrategy",
132279
"module_path": "qlib.contrib.strategy.rule_strategy",
133280
"kwargs": {
@@ -139,14 +286,14 @@ def single(
139286
},
140287
}
141288

142-
top_executor_config = _get_multi_level_executor_config(
289+
executor_config = _get_multi_level_executor_config(
143290
strategy_config=backtest_config["strategies"],
144291
cash_limit=cash_limit,
145292
generate_report=generate_report,
146293
)
147294

148-
tmp_backtest_config = copy.deepcopy(backtest_config["exchange"])
149-
tmp_backtest_config.update(
295+
exchange_config = copy.deepcopy(backtest_config["exchange"])
296+
exchange_config.update(
150297
{
151298
"codes": stocks,
152299
"freq": "1min",
@@ -156,11 +303,11 @@ def single(
156303
strategy, executor = get_strategy_executor(
157304
start_time=pd.Timestamp(trade_start_time),
158305
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
159-
strategy=top_strategy_config,
160-
executor=top_executor_config,
306+
strategy=strategy_config,
307+
executor=executor_config,
161308
benchmark=None,
162309
account=cash_limit if cash_limit is not None else int(1e12),
163-
exchange_kwargs=tmp_backtest_config,
310+
exchange_kwargs=exchange_config,
164311
pos_type="Position" if cash_limit is not None else "InfPosition",
165312
)
166313
_set_env_for_all_strategy(executor=executor)
@@ -172,7 +319,7 @@ def single(
172319
assert records is None or not np.isnan(records["ffr"]).any()
173320

174321
if generate_report:
175-
report = _generate_report(decisions, report_dict)
322+
report = _generate_report(decisions, [report_dict["indicator"]])
176323
if split == "stock":
177324
stock_id = orders.iloc[0].instrument
178325
report = {stock_id: report}
@@ -184,7 +331,7 @@ def single(
184331
return records
185332

186333

187-
def backtest(backtest_config: dict) -> pd.DataFrame:
334+
def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFrame:
188335
order_df = read_order_file(backtest_config["order_file"])
189336

190337
cash_limit = backtest_config["exchange"].pop("cash_limit")
@@ -193,6 +340,7 @@ def backtest(backtest_config: dict) -> pd.DataFrame:
193340
stock_pool = order_df["instrument"].unique().tolist()
194341
stock_pool.sort()
195342

343+
single = single_with_simulator if with_simulator else single_with_collect_data_loop
196344
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
197345
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
198346
res = Parallel(**mp_config)(
@@ -227,5 +375,12 @@ def backtest(backtest_config: dict) -> pd.DataFrame:
227375
warnings.filterwarnings("ignore", category=DeprecationWarning)
228376
warnings.filterwarnings("ignore", category=RuntimeWarning)
229377

230-
path = sys.argv[1]
231-
backtest(get_backtest_config_fromfile(path))
378+
parser = argparse.ArgumentParser()
379+
parser.add_argument("--config_path", type=str, required=True, help="Path to the config file")
380+
parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend")
381+
args = parser.parse_args()
382+
383+
backtest(
384+
backtest_config=get_backtest_config_fromfile(args.config_path),
385+
with_simulator=args.use_simulator,
386+
)

qlib/rl/contrib/naive_config_parser.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def parse_backtest_config(path: str) -> dict:
5353

5454
del sys.modules[tmp_module_name]
5555
else:
56-
config = yaml.safe_load(open(tmp_config_file.name))
56+
with open(tmp_config_file.name) as input_stream:
57+
config = yaml.safe_load(input_stream)
5758

5859
if "_base_" in config:
5960
base_file_name = config.pop("_base_")

qlib/rl/data/integration.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ def init_qlib(qlib_config: dict, part: str = None) -> None:
8181
def _convert_to_path(path: str | Path) -> Path:
8282
return path if isinstance(path, Path) else Path(path)
8383

84-
provider_uri_map = {
85-
"day": _convert_to_path(qlib_config["provider_uri_day"]).as_posix(),
86-
"1min": _convert_to_path(qlib_config["provider_uri_1min"]).as_posix(),
87-
}
84+
provider_uri_map = {}
85+
if "provider_uri_day" in qlib_config:
86+
provider_uri_map["day"] = _convert_to_path(qlib_config["provider_uri_day"]).as_posix()
87+
if "provider_uri_1min" in qlib_config:
88+
provider_uri_map["1min"] = _convert_to_path(qlib_config["provider_uri_1min"]).as_posix()
89+
8890
qlib.init(
8991
region=REG_CN,
9092
auto_mount=False,

0 commit comments

Comments
 (0)