Skip to content

Commit a41b18b

Browse files
authored
Migrate backtest logic from NT (microsoft#1263)
* Backtest migration * Minor bug fix in test * Reorganize file to avoid loop import * Fix test SAOE bug * Remove unnecessary names * Resolve PR comments; remove private classes; * Fix CI error * Resolve PR comments * Refactor data interfaces * Remove convert_instance_config and change config * Pylint issue * Pylint issue * Fix tempfile warning * Resolve PR comments * Add more comments
1 parent 2a0ec49 commit a41b18b

19 files changed

+794
-118
lines changed

qlib/backtest/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def get_exchange(
114114
def create_account_instance(
115115
start_time: Union[pd.Timestamp, str],
116116
end_time: Union[pd.Timestamp, str],
117-
benchmark: str,
117+
benchmark: Optional[str],
118118
account: Union[float, int, dict],
119119
pos_type: str = "Position",
120120
) -> Account:
@@ -163,7 +163,9 @@ def create_account_instance(
163163
init_cash=init_cash,
164164
position_dict=position_dict,
165165
pos_type=pos_type,
166-
benchmark_config={
166+
benchmark_config={}
167+
if benchmark is None
168+
else {
167169
"benchmark": benchmark,
168170
"start_time": start_time,
169171
"end_time": end_time,
@@ -176,7 +178,7 @@ def get_strategy_executor(
176178
end_time: Union[pd.Timestamp, str],
177179
strategy: Union[str, dict, object, Path],
178180
executor: Union[str, dict, object, Path],
179-
benchmark: str = "SH000300",
181+
benchmark: Optional[str] = "SH000300",
180182
account: Union[float, int, dict] = 1e9,
181183
exchange_kwargs: dict = {},
182184
pos_type: str = "Position",

qlib/rl/contrib/backtest.py

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
import copy
6+
import pickle
7+
import sys
8+
from pathlib import Path
9+
from typing import Optional, Tuple, Union
10+
11+
import numpy as np
12+
import pandas as pd
13+
import torch
14+
from joblib import Parallel, delayed
15+
16+
from qlib.backtest import collect_data_loop, get_strategy_executor
17+
from qlib.backtest.decision import TradeRangeByTime
18+
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
19+
from qlib.backtest.high_performance_ds import BaseOrderIndicator
20+
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
21+
from qlib.rl.contrib.utils import read_order_file
22+
from qlib.rl.data.integration import init_qlib
23+
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper
24+
25+
26+
def _get_multi_level_executor_config(
27+
strategy_config: dict,
28+
cash_limit: float = None,
29+
generate_report: bool = False,
30+
) -> dict:
31+
executor_config = {
32+
"class": "SimulatorExecutor",
33+
"module_path": "qlib.backtest.executor",
34+
"kwargs": {
35+
"time_per_step": "1min",
36+
"verbose": False,
37+
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
38+
"generate_report": generate_report,
39+
"track_data": True,
40+
},
41+
}
42+
43+
freqs = list(strategy_config.keys())
44+
freqs.sort(key=lambda x: pd.Timedelta(x))
45+
for freq in freqs:
46+
executor_config = {
47+
"class": "NestedExecutor",
48+
"module_path": "qlib.backtest.executor",
49+
"kwargs": {
50+
"time_per_step": freq,
51+
"inner_strategy": strategy_config[freq],
52+
"inner_executor": executor_config,
53+
"track_data": True,
54+
},
55+
}
56+
57+
return executor_config
58+
59+
60+
def _set_env_for_all_strategy(executor: BaseExecutor) -> None:
61+
if isinstance(executor, NestedExecutor):
62+
if hasattr(executor.inner_strategy, "set_env"):
63+
env = CollectDataEnvWrapper()
64+
env.reset()
65+
executor.inner_strategy.set_env(env)
66+
_set_env_for_all_strategy(executor.inner_executor)
67+
68+
69+
def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
70+
record_list = []
71+
for time, value_dict in indicator.items():
72+
if isinstance(value_dict, BaseOrderIndicator):
73+
# HACK: for qlib v0.8
74+
value_dict = value_dict.to_series()
75+
try:
76+
value_dict = {k: v for k, v in value_dict.items()}
77+
if value_dict["ffr"].empty:
78+
continue
79+
except Exception:
80+
value_dict = {k: v for k, v in value_dict.items() if k != "pa"}
81+
value_dict = pd.DataFrame(value_dict)
82+
value_dict["datetime"] = time
83+
record_list.append(value_dict)
84+
85+
if not record_list:
86+
return None
87+
88+
records: pd.DataFrame = pd.concat(record_list, 0).reset_index().rename(columns={"index": "instrument"})
89+
records = records.set_index(["instrument", "datetime"])
90+
return records
91+
92+
93+
def _generate_report(decisions: list, report_dict: dict) -> dict:
94+
report = {}
95+
decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")])
96+
for key in ["1minute", "5minute", "30minute", "1day"]:
97+
if key not in report_dict["indicator"]:
98+
continue
99+
report[key] = report_dict["indicator"][key]
100+
report[key + "_obj"] = _convert_indicator_to_dataframe(
101+
report_dict["indicator"][key + "_obj"].order_indicator_his
102+
)
103+
cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"])
104+
if len(cur_details) > 0:
105+
cur_details.pop("freq")
106+
report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer")
107+
if "1minute" in report_dict["report"]:
108+
report["simulator"] = report_dict["report"]["1minute"][0]
109+
return report
110+
111+
112+
def single(
113+
backtest_config: dict,
114+
orders: pd.DataFrame,
115+
split: str = "stock",
116+
cash_limit: float = None,
117+
generate_report: bool = False,
118+
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
119+
if split == "stock":
120+
stock_id = orders.iloc[0].instrument
121+
init_qlib(backtest_config["qlib"], part=stock_id)
122+
else:
123+
day = orders.iloc[0].datetime
124+
init_qlib(backtest_config["qlib"], part=day)
125+
126+
trade_start_time = orders["datetime"].min()
127+
trade_end_time = orders["datetime"].max()
128+
stocks = orders.instrument.unique().tolist()
129+
130+
top_strategy_config = {
131+
"class": "FileOrderStrategy",
132+
"module_path": "qlib.contrib.strategy.rule_strategy",
133+
"kwargs": {
134+
"file": orders,
135+
"trade_range": TradeRangeByTime(
136+
pd.Timestamp(backtest_config["start_time"]).time(),
137+
pd.Timestamp(backtest_config["end_time"]).time(),
138+
),
139+
},
140+
}
141+
142+
top_executor_config = _get_multi_level_executor_config(
143+
strategy_config=backtest_config["strategies"],
144+
cash_limit=cash_limit,
145+
generate_report=generate_report,
146+
)
147+
148+
tmp_backtest_config = copy.deepcopy(backtest_config["exchange"])
149+
tmp_backtest_config.update(
150+
{
151+
"codes": stocks,
152+
"freq": "1min",
153+
}
154+
)
155+
156+
strategy, executor = get_strategy_executor(
157+
start_time=pd.Timestamp(trade_start_time),
158+
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
159+
strategy=top_strategy_config,
160+
executor=top_executor_config,
161+
benchmark=None,
162+
account=cash_limit if cash_limit is not None else int(1e12),
163+
exchange_kwargs=tmp_backtest_config,
164+
pos_type="Position" if cash_limit is not None else "InfPosition",
165+
)
166+
_set_env_for_all_strategy(executor=executor)
167+
168+
report_dict: dict = {}
169+
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))
170+
171+
records = _convert_indicator_to_dataframe(report_dict["indicator"]["1day_obj"].order_indicator_his)
172+
assert records is None or not np.isnan(records["ffr"]).any()
173+
174+
if generate_report:
175+
report = _generate_report(decisions, report_dict)
176+
if split == "stock":
177+
stock_id = orders.iloc[0].instrument
178+
report = {stock_id: report}
179+
else:
180+
day = orders.iloc[0].datetime
181+
report = {day: report}
182+
return records, report
183+
else:
184+
return records
185+
186+
187+
def backtest(backtest_config: dict) -> pd.DataFrame:
188+
order_df = read_order_file(backtest_config["order_file"])
189+
190+
cash_limit = backtest_config["exchange"].pop("cash_limit")
191+
generate_report = backtest_config["exchange"].pop("generate_report")
192+
193+
stock_pool = order_df["instrument"].unique().tolist()
194+
stock_pool.sort()
195+
196+
mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
197+
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
198+
res = Parallel(**mp_config)(
199+
delayed(single)(
200+
backtest_config=backtest_config,
201+
orders=order_df[order_df["instrument"] == stock].copy(),
202+
split="stock",
203+
cash_limit=cash_limit,
204+
generate_report=generate_report,
205+
)
206+
for stock in stock_pool
207+
)
208+
209+
output_path = Path(backtest_config["output_dir"])
210+
if generate_report:
211+
with (output_path / "report.pkl").open("wb") as f:
212+
report = {}
213+
for r in res:
214+
report.update(r[1])
215+
pickle.dump(report, f)
216+
res = pd.concat([r[0] for r in res], 0)
217+
else:
218+
res = pd.concat(res)
219+
220+
res.to_csv(output_path / "summary.csv")
221+
return res
222+
223+
224+
if __name__ == "__main__":
225+
import warnings
226+
227+
warnings.filterwarnings("ignore", category=DeprecationWarning)
228+
warnings.filterwarnings("ignore", category=RuntimeWarning)
229+
230+
path = sys.argv[1]
231+
backtest(get_backtest_config_fromfile(path))
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import os
5+
import platform
6+
import shutil
7+
import sys
8+
import tempfile
9+
from importlib import import_module
10+
11+
import yaml
12+
13+
14+
def merge_a_into_b(a: dict, b: dict) -> dict:
15+
b = b.copy()
16+
for k, v in a.items():
17+
if isinstance(v, dict) and k in b:
18+
v.pop("_delete_", False) # TODO: make this more elegant
19+
b[k] = merge_a_into_b(v, b[k])
20+
else:
21+
b[k] = v
22+
return b
23+
24+
25+
def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None:
26+
if not os.path.isfile(filename):
27+
raise FileNotFoundError(msg_tmpl.format(filename))
28+
29+
30+
def parse_backtest_config(path: str) -> dict:
31+
abs_path = os.path.abspath(path)
32+
check_file_exist(abs_path)
33+
34+
file_ext_name = os.path.splitext(abs_path)[1]
35+
if file_ext_name not in (".py", ".json", ".yaml", ".yml"):
36+
raise IOError("Only py/yml/yaml/json type are supported now!")
37+
38+
with tempfile.TemporaryDirectory() as tmp_config_dir:
39+
with tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name) as tmp_config_file:
40+
if platform.system() == "Windows":
41+
tmp_config_file.close()
42+
43+
tmp_config_name = os.path.basename(tmp_config_file.name)
44+
shutil.copyfile(abs_path, tmp_config_file.name)
45+
46+
if abs_path.endswith(".py"):
47+
tmp_module_name = os.path.splitext(tmp_config_name)[0]
48+
sys.path.insert(0, tmp_config_dir)
49+
module = import_module(tmp_module_name)
50+
sys.path.pop(0)
51+
52+
config = {k: v for k, v in module.__dict__.items() if not k.startswith("__")}
53+
54+
del sys.modules[tmp_module_name]
55+
else:
56+
config = yaml.safe_load(open(tmp_config_file.name))
57+
58+
if "_base_" in config:
59+
base_file_name = config.pop("_base_")
60+
if not isinstance(base_file_name, list):
61+
base_file_name = [base_file_name]
62+
63+
for f in base_file_name:
64+
base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
65+
config = merge_a_into_b(a=config, b=base_config)
66+
67+
return config
68+
69+
70+
def _convert_all_list_to_tuple(config: dict) -> dict:
71+
for k, v in config.items():
72+
if isinstance(v, list):
73+
config[k] = tuple(v)
74+
elif isinstance(v, dict):
75+
config[k] = _convert_all_list_to_tuple(v)
76+
return config
77+
78+
79+
def get_backtest_config_fromfile(path: str) -> dict:
80+
backtest_config = parse_backtest_config(path)
81+
82+
exchange_config_default = {
83+
"open_cost": 0.0005,
84+
"close_cost": 0.0015,
85+
"min_cost": 5.0,
86+
"trade_unit": 100.0,
87+
"cash_limit": None,
88+
"generate_report": False,
89+
}
90+
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
91+
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])
92+
93+
backtest_config_default = {
94+
"debug_single_stock": None,
95+
"debug_single_day": None,
96+
"concurrency": -1,
97+
"multiplier": 1.0,
98+
"output_dir": "outputs/",
99+
# "runtime": {},
100+
}
101+
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)
102+
103+
return backtest_config

qlib/rl/contrib/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from __future__ import annotations
5+
6+
from pathlib import Path
7+
8+
import pandas as pd
9+
10+
11+
def read_order_file(order_file: Path | pd.DataFrame) -> pd.DataFrame:
12+
if isinstance(order_file, pd.DataFrame):
13+
return order_file
14+
15+
order_file = Path(order_file)
16+
17+
if order_file.suffix == ".pkl":
18+
order_df = pd.read_pickle(order_file).reset_index()
19+
elif order_file.suffix == ".csv":
20+
order_df = pd.read_csv(order_file)
21+
else:
22+
raise TypeError(f"Unsupported order file type: {order_file}")
23+
24+
if "date" in order_df.columns:
25+
# legacy dataframe columns
26+
order_df = order_df.rename(columns={"date": "datetime", "order_type": "direction"})
27+
order_df["datetime"] = order_df["datetime"].astype(str)
28+
29+
return order_df

0 commit comments

Comments
 (0)