2
2
# Licensed under the MIT License.
3
3
from __future__ import annotations
4
4
5
+ import argparse
5
6
import copy
6
7
import pickle
7
- import sys
8
+ from collections import defaultdict
8
9
from pathlib import Path
9
- from typing import Optional , Tuple , Union
10
+ from typing import List , Literal , Optional , Tuple , Union
10
11
11
12
import numpy as np
12
13
import pandas as pd
13
14
import torch
14
15
from joblib import Parallel , delayed
15
16
16
17
from qlib .backtest import collect_data_loop , get_strategy_executor
17
- from qlib .backtest .decision import TradeRangeByTime
18
+ from qlib .backtest .decision import BaseTradeDecision , Order , OrderDir , TradeRangeByTime
18
19
from qlib .backtest .executor import BaseExecutor , NestedExecutor , SimulatorExecutor
19
20
from qlib .backtest .high_performance_ds import BaseOrderIndicator
20
21
from qlib .rl .contrib .naive_config_parser import get_backtest_config_fromfile
21
22
from qlib .rl .contrib .utils import read_order_file
22
23
from qlib .rl .data .integration import init_qlib
24
+ from qlib .rl .order_execution .simulator_qlib import SingleAssetOrderExecution
23
25
from qlib .rl .utils .env_wrapper import CollectDataEnvWrapper
24
26
25
27
@@ -41,7 +43,7 @@ def _get_multi_level_executor_config(
41
43
}
42
44
43
45
freqs = list (strategy_config .keys ())
44
- freqs .sort (key = lambda x : pd .Timedelta ( x ) )
46
+ freqs .sort (key = pd .Timedelta )
45
47
for freq in freqs :
46
48
executor_config = {
47
49
"class" : "NestedExecutor" ,
@@ -73,7 +75,7 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
73
75
# HACK: for qlib v0.8
74
76
value_dict = value_dict .to_series ()
75
77
try :
76
- value_dict = { k : v for k , v in value_dict . items ()}
78
+ value_dict = copy . deepcopy ( value_dict )
77
79
if value_dict ["ffr" ].empty :
78
80
continue
79
81
except Exception :
@@ -90,32 +92,177 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
90
92
return records
91
93
92
94
93
- def _generate_report (decisions : list , report_dict : dict ) -> dict :
95
+ # TODO: there should be richer annotation for the input (e.g. report) and the returned report
96
+ # TODO: For example, @ dataclass with typed fields and detailed docstrings.
97
+ def _generate_report (decisions : List [BaseTradeDecision ], report_indicators : List [dict ]) -> dict :
98
+ """Generate backtest reports
99
+
100
+ Parameters
101
+ ----------
102
+ decisions:
103
+ List of trade decisions.
104
+ report_indicators
105
+ List of indicator reports.
106
+ Returns
107
+ -------
108
+
109
+ """
110
+ indicator_dict = defaultdict (list )
111
+ indicator_his = defaultdict (list )
112
+ for report_indicator in report_indicators :
113
+ for key , value in report_indicator .items ():
114
+ if key .endswith ("_obj" ):
115
+ indicator_his [key ].append (value .order_indicator_his )
116
+ else :
117
+ indicator_dict [key ].append (value )
118
+
94
119
report = {}
95
- decision_details = pd .concat ([d . details for d in decisions if hasattr (d , "details" )])
96
- for key in ["1minute " , "5minute " , "30minute " , "1day" ]:
97
- if key not in report_dict [ "indicator" ] :
120
+ decision_details = pd .concat ([getattr ( d , " details" ) for d in decisions if hasattr (d , "details" )])
121
+ for key in ["1min " , "5min " , "30min " , "1day" ]:
122
+ if key not in indicator_dict :
98
123
continue
99
- report [ key ] = report_dict [ "indicator" ][ key ]
100
- report [key + "_obj" ] = _convert_indicator_to_dataframe (
101
- report_dict [ "indicator" ][ key + "_obj" ]. order_indicator_his
102
- )
103
- cur_details = decision_details [decision_details .freq == key . rstrip ( "ute" ) ].set_index (["instrument" , "datetime" ])
124
+
125
+ report [key ] = pd . concat ( indicator_dict [ key ])
126
+ report [ key + "_obj" ] = pd . concat ([ _convert_indicator_to_dataframe ( his ) for his in indicator_his [ key + "_obj" ]])
127
+
128
+ cur_details = decision_details [decision_details .freq == key ].set_index (["instrument" , "datetime" ])
104
129
if len (cur_details ) > 0 :
105
130
cur_details .pop ("freq" )
106
131
report [key + "_obj" ] = report [key + "_obj" ].join (cur_details , how = "outer" )
107
- if "1minute" in report_dict ["report" ]:
108
- report ["simulator" ] = report_dict ["report" ]["1minute" ][0 ]
132
+
109
133
return report
110
134
111
135
112
- def single (
136
def single_with_simulator(
    backtest_config: dict,
    orders: pd.DataFrame,
    split: Literal["stock", "day"] = "stock",
    cash_limit: float = None,
    generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
    """Run backtest in a single thread, driving one SingleAssetOrderExecution simulator per order row.

    Orders are processed row by row; a fresh simulator is built for every row, and its trading window
    is the configured start/end time moved onto that row's date.

    Parameters
    ----------
    backtest_config:
        Backtest config. The keys read here are "qlib", "start_time", "end_time", "strategies" and
        "exchange".
    orders:
        Orders to be executed. Example format:
                 datetime instrument  amount  direction
            0  2020-06-01       INST   600.0          0
            1  2020-06-02       INST   700.0          1
            ...
    split
        Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
        The instrument (respectively datetime) of the first order is used to initialise qlib and as the
        key of the returned report.
    cash_limit
        Limitation of cash. It is forwarded to the executor config.
    generate_report
        Whether to generate reports.

    Returns
    -------
    If generate_report is True, return execution records and the generated report (wrapped in a
    single-entry dict keyed by the split key). Otherwise, return only records.
    """
    first_order = orders.iloc[0]
    part = first_order.instrument if split == "stock" else first_order.datetime
    init_qlib(backtest_config["qlib"], part=part)

    def _on_trade_day(template, trade_date: pd.Timestamp) -> pd.Timestamp:
        # Keep the configured time of day, but move it onto the order's calendar date.
        return pd.Timestamp(template).replace(year=trade_date.year, month=trade_date.month, day=trade_date.day)

    stocks = orders.instrument.unique().tolist()

    report_dicts = []
    decisions = []
    for _, order_row in orders.iterrows():
        trade_date = pd.Timestamp(order_row["datetime"])
        start_time = _on_trade_day(backtest_config["start_time"], trade_date)
        end_time = _on_trade_day(backtest_config["end_time"], trade_date)
        order = Order(
            stock_id=order_row["instrument"],
            amount=order_row["amount"],
            direction=OrderDir(order_row["direction"]),
            start_time=start_time,
            end_time=end_time,
        )

        executor_config = _get_multi_level_executor_config(
            strategy_config=backtest_config["strategies"],
            cash_limit=cash_limit,
            generate_report=generate_report,
        )

        # Work on a private copy so the caller's exchange config is never modified.
        exchange_config = copy.deepcopy(backtest_config["exchange"])
        exchange_config.update(codes=stocks, freq="1min")

        # NOTE(review): the simulator itself receives cash_limit=None even when a limit was supplied;
        # the limit only reaches the executor config above -- confirm this is intended.
        simulator = SingleAssetOrderExecution(
            order=order,
            executor_config=executor_config,
            exchange_config=exchange_config,
            qlib_config=None,
            cash_limit=None,
            backtest_mode=True,
        )

        report_dicts.append(simulator.report_dict)
        decisions.extend(simulator.decisions)

    # Merge the daily order-indicator histories of all runs; later runs win on duplicate keys.
    indicator = {}
    for report_dict in report_dicts:
        indicator.update(report_dict["indicator"]["1day_obj"].order_indicator_his.items())

    records = _convert_indicator_to_dataframe(indicator)
    assert records is None or not np.isnan(records["ffr"]).any()

    if not generate_report:
        return records

    report = _generate_report(decisions, [report_dict["indicator"] for report_dict in report_dicts])
    return records, {part: report}
233
+
234
+
235
+ def single_with_collect_data_loop (
236
+ backtest_config : dict ,
237
+ orders : pd .DataFrame ,
238
+ split : Literal ["stock" , "day" ] = "stock" ,
239
+ cash_limit : float = None ,
240
+ generate_report : bool = False ,
241
+ ) -> Union [Tuple [pd .DataFrame , dict ], pd .DataFrame ]:
242
+ """Run backtest in a single thread with collect_data_loop.
243
+
244
+ Parameters
245
+ ----------
246
+ backtest_config:
247
+ Backtest config
248
+ orders:
249
+ Orders to be executed. Example format:
250
+ datetime instrument amount direction
251
+ 0 2020-06-01 INST 600.0 0
252
+ 1 2020-06-02 INST 700.0 1
253
+ ...
254
+ split
255
+ Method to split orders. If it is "stock", split orders by stock. If it is "day", split orders by date.
256
+ cash_limit
257
+ Limitation of cash.
258
+ generate_report
259
+ Whether to generate reports.
260
+
261
+ Returns
262
+ -------
263
+ If generate_report is True, return execution records and the generated report. Otherwise, return only records.
264
+ """
265
+
119
266
if split == "stock" :
120
267
stock_id = orders .iloc [0 ].instrument
121
268
init_qlib (backtest_config ["qlib" ], part = stock_id )
@@ -127,7 +274,7 @@ def single(
127
274
trade_end_time = orders ["datetime" ].max ()
128
275
stocks = orders .instrument .unique ().tolist ()
129
276
130
- top_strategy_config = {
277
+ strategy_config = {
131
278
"class" : "FileOrderStrategy" ,
132
279
"module_path" : "qlib.contrib.strategy.rule_strategy" ,
133
280
"kwargs" : {
@@ -139,14 +286,14 @@ def single(
139
286
},
140
287
}
141
288
142
- top_executor_config = _get_multi_level_executor_config (
289
+ executor_config = _get_multi_level_executor_config (
143
290
strategy_config = backtest_config ["strategies" ],
144
291
cash_limit = cash_limit ,
145
292
generate_report = generate_report ,
146
293
)
147
294
148
- tmp_backtest_config = copy .deepcopy (backtest_config ["exchange" ])
149
- tmp_backtest_config .update (
295
+ exchange_config = copy .deepcopy (backtest_config ["exchange" ])
296
+ exchange_config .update (
150
297
{
151
298
"codes" : stocks ,
152
299
"freq" : "1min" ,
@@ -156,11 +303,11 @@ def single(
156
303
strategy , executor = get_strategy_executor (
157
304
start_time = pd .Timestamp (trade_start_time ),
158
305
end_time = pd .Timestamp (trade_end_time ) + pd .DateOffset (1 ),
159
- strategy = top_strategy_config ,
160
- executor = top_executor_config ,
306
+ strategy = strategy_config ,
307
+ executor = executor_config ,
161
308
benchmark = None ,
162
309
account = cash_limit if cash_limit is not None else int (1e12 ),
163
- exchange_kwargs = tmp_backtest_config ,
310
+ exchange_kwargs = exchange_config ,
164
311
pos_type = "Position" if cash_limit is not None else "InfPosition" ,
165
312
)
166
313
_set_env_for_all_strategy (executor = executor )
@@ -172,7 +319,7 @@ def single(
172
319
assert records is None or not np .isnan (records ["ffr" ]).any ()
173
320
174
321
if generate_report :
175
- report = _generate_report (decisions , report_dict )
322
+ report = _generate_report (decisions , [ report_dict [ "indicator" ]] )
176
323
if split == "stock" :
177
324
stock_id = orders .iloc [0 ].instrument
178
325
report = {stock_id : report }
@@ -184,7 +331,7 @@ def single(
184
331
return records
185
332
186
333
187
- def backtest (backtest_config : dict ) -> pd .DataFrame :
334
+ def backtest (backtest_config : dict , with_simulator : bool = False ) -> pd .DataFrame :
188
335
order_df = read_order_file (backtest_config ["order_file" ])
189
336
190
337
cash_limit = backtest_config ["exchange" ].pop ("cash_limit" )
@@ -193,6 +340,7 @@ def backtest(backtest_config: dict) -> pd.DataFrame:
193
340
stock_pool = order_df ["instrument" ].unique ().tolist ()
194
341
stock_pool .sort ()
195
342
343
+ single = single_with_simulator if with_simulator else single_with_collect_data_loop
196
344
mp_config = {"n_jobs" : backtest_config ["concurrency" ], "verbose" : 10 , "backend" : "multiprocessing" }
197
345
torch .set_num_threads (1 ) # https://github.com/pytorch/pytorch/issues/17199
198
346
res = Parallel (** mp_config )(
@@ -227,5 +375,12 @@ def backtest(backtest_config: dict) -> pd.DataFrame:
227
375
warnings .filterwarnings ("ignore" , category = DeprecationWarning )
228
376
warnings .filterwarnings ("ignore" , category = RuntimeWarning )
229
377
230
- path = sys .argv [1 ]
231
- backtest (get_backtest_config_fromfile (path ))
378
+ parser = argparse .ArgumentParser ()
379
+ parser .add_argument ("--config_path" , type = str , required = True , help = "Path to the config file" )
380
+ parser .add_argument ("--use_simulator" , action = "store_true" , help = "Whether to use simulator as the backend" )
381
+ args = parser .parse_args ()
382
+
383
+ backtest (
384
+ backtest_config = get_backtest_config_fromfile (args .config_path ),
385
+ with_simulator = args .use_simulator ,
386
+ )
0 commit comments