Skip to content

Commit 54ef210

Browse files
lihuoranluocy16
andauthored
Migrate amc4th training (microsoft#1316)
* Migrate amc4th training * Refine RL example scripts * Resolve PR comments Co-authored-by: luocy16 <luocy16@mails.tsinghua.edu.cn>
1 parent f35d8a3 commit 54ef210

19 files changed

+676
-50
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ qlib/VERSION.txt
2424
qlib/data/_libs/expanding.cpp
2525
qlib/data/_libs/rolling.cpp
2626
examples/estimator/estimator_example/
27+
examples/rl/data/
28+
examples/rl/checkpoints/
29+
examples/rl/outputs/
2730

2831
*.egg-info/
2932

examples/rl/README.md

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
This folder contains a simple example of how to run Qlib RL. It contains:
2+
3+
```
4+
.
5+
├── experiment_config
6+
│ ├── backtest # Backtest config
7+
│ └── training # Training config
8+
├── README.md # Readme (the current file)
9+
└── scripts # Scripts for data pre-processing
10+
```
11+
12+
## Data preparation
13+
14+
Use [AzCopy](https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to download data:
15+
16+
```
17+
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl/qlib_rl_example_data ./ --recursive
18+
mv qlib_rl_example_data data
19+
```
20+
21+
The downloaded data will be placed in `./data`, with the original CSV files under `data/csv`. To generate all the data needed by this example, run:
22+
23+
```
24+
bash scripts/data_pipeline.sh
25+
```
26+
27+
After the script finishes, the `data/` directory should look like this:
28+
29+
```
30+
data
31+
├── backtest_orders.csv
32+
├── bin
33+
├── csv
34+
├── pickle
35+
├── pickle_dataframe
36+
└── training_order_split
37+
```
38+
39+
## Run training
40+
41+
Run the following command from this folder (the paths are relative to it):
42+
43+
```
44+
python ../../qlib/rl/contrib/train_onpolicy.py --config_path ./experiment_config/training/config.yml
45+
```
46+
47+
After training, checkpoints will be stored under `checkpoints/`.
48+
49+
## Run backtest
50+
51+
```
52+
python ../../qlib/rl/contrib/backtest.py --config_path ./experiment_config/backtest/config.py
53+
```
54+
55+
The backtest workflow will use the trained model in `checkpoints/`. The backtest summary can be found in `outputs/`.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
_base_ = ["./twap.yml"]
2+
3+
strategies = {
4+
"_delete_": True,
5+
"30min": {
6+
"class": "TWAPStrategy",
7+
"module_path": "qlib.contrib.strategy.rule_strategy",
8+
"kwargs": {},
9+
},
10+
"1day": {
11+
"class": "SAOEIntStrategy",
12+
"module_path": "qlib.rl.order_execution.strategy",
13+
"kwargs": {
14+
"state_interpreter": {
15+
"class": "FullHistoryStateInterpreter",
16+
"module_path": "qlib.rl.order_execution.interpreter",
17+
"kwargs": {
18+
"max_step": 8,
19+
"data_ticks": 240,
20+
"data_dim": 6,
21+
"processed_data_provider": {
22+
"class": "PickleProcessedDataProvider",
23+
"module_path": "qlib.rl.data.pickle_styled",
24+
"kwargs": {
25+
"data_dir": "./data/pickle_dataframe/feature",
26+
},
27+
},
28+
},
29+
},
30+
"action_interpreter": {
31+
"class": "CategoricalActionInterpreter",
32+
"module_path": "qlib.rl.order_execution.interpreter",
33+
"kwargs": {
34+
"values": 14,
35+
"max_step": 8,
36+
},
37+
},
38+
"network": {
39+
"class": "Recurrent",
40+
"module_path": "qlib.rl.order_execution.network",
41+
"kwargs": {},
42+
},
43+
"policy": {
44+
"class": "PPO",
45+
"module_path": "qlib.rl.order_execution.policy",
46+
"kwargs": {
47+
"lr": 1.0e-4,
48+
"weight_file": "./checkpoints/latest.pth",
49+
},
50+
},
51+
},
52+
},
53+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
order_file: ./data/backtest_orders.csv
2+
start_time: "9:45"
3+
end_time: "14:44"
4+
qlib:
5+
provider_uri_1min: ./data/bin
6+
feature_root_dir: ./data/pickle
7+
feature_columns_today: [
8+
"$open", "$high", "$low", "$close", "$vwap", "$volume",
9+
]
10+
feature_columns_yesterday: [
11+
"$open_v1", "$high_v1", "$low_v1", "$close_v1", "$vwap_v1", "$volume_v1",
12+
]
13+
exchange:
14+
limit_threshold: ['$close == 0', '$close == 0']
15+
deal_price: ["If($close == 0, $vwap, $close)", "If($close == 0, $vwap, $close)"]
16+
volume_threshold:
17+
all: ["cum", "0.2 * DayCumsum($volume, '9:45', '14:44')"]
18+
buy: ["current", "$close"]
19+
sell: ["current", "$close"]
20+
strategies: {} # Placeholder
21+
concurrency: 5
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
simulator:
2+
time_per_step: 30
3+
vol_limit: null
4+
env:
5+
concurrency: 1
6+
parallel_mode: dummy
7+
action_interpreter:
8+
class: CategoricalActionInterpreter
9+
kwargs:
10+
values: 14
11+
max_step: 8
12+
module_path: qlib.rl.order_execution.interpreter
13+
state_interpreter:
14+
class: FullHistoryStateInterpreter
15+
kwargs:
16+
data_dim: 6
17+
data_ticks: 240
18+
max_step: 8
19+
processed_data_provider:
20+
class: PickleProcessedDataProvider
21+
module_path: qlib.rl.data.pickle_styled
22+
kwargs:
23+
data_dir: ./data/pickle_dataframe/feature
24+
module_path: qlib.rl.order_execution.interpreter
25+
reward:
26+
class: PAPenaltyReward
27+
kwargs:
28+
penalty: 100.0
29+
module_path: qlib.rl.order_execution.reward
30+
data:
31+
source:
32+
order_dir: ./data/training_order_split
33+
data_dir: ./data/pickle_dataframe/backtest
34+
total_time: 240
35+
default_start_time: 0
36+
default_end_time: 240
37+
proc_data_dim: 6
38+
num_workers: 0
39+
queue_size: 20
40+
network:
41+
class: Recurrent
42+
module_path: qlib.rl.order_execution.network
43+
policy:
44+
class: PPO
45+
kwargs:
46+
lr: 0.0001
47+
module_path: qlib.rl.order_execution.policy
48+
runtime:
49+
seed: 42
50+
use_cuda: false
51+
trainer:
52+
max_epoch: 2
53+
repeat_per_collect: 5
54+
earlystop_patience: 2
55+
episode_per_collect: 20
56+
batch_size: 16
57+
val_every_n_epoch: 1
58+
checkpoint_path: ./checkpoints
59+
checkpoint_every_n_iters: 1
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import os
5+
import pickle
6+
import pandas as pd
7+
from tqdm import tqdm
8+
9+
def add_date_column(df):
    """Return a copy of ``df`` with a ``date`` column appended.

    ``date`` is the ``datetime`` column truncated to midnight. The input frame
    is left untouched.

    The original ``.dt.date.astype("datetime64")`` cast has no unit, which
    pandas >= 2.0 rejects; ``normalize()`` yields the same midnight timestamps
    without a round trip through Python ``date`` objects.
    NOTE(review): assumes ``datetime`` is timezone-naive -- confirm.
    """
    return df.assign(date=df["datetime"].dt.normalize())


def select_instrument(df, instrument):
    """Return the rows of ``instrument`` sorted by time.

    The result is indexed by ``(instrument, datetime, date)``.
    """
    cur = df[df["instrument"] == instrument].sort_values(by=["datetime"])
    return cur.set_index(["instrument", "datetime", "date"])


def main():
    """Split ``data/pickle/{backtest,feature}.pkl`` into one pickle per instrument.

    Each input pickle is expected to hold a mapping whose values are
    DataFrames (it is consumed through ``.values()``). The per-instrument
    frames are written to ``data/pickle_dataframe/<tag>/<instrument>.pkl``.
    """
    for tag in ("backtest", "feature"):
        # Context managers make sure every file handle is closed promptly.
        with open(os.path.join("data", "pickle", f"{tag}.pkl"), "rb") as src:
            frames = pickle.load(src)
        df = add_date_column(pd.concat(list(frames.values())).reset_index())
        instruments = sorted(set(df["instrument"]))

        # makedirs also creates the parent ``data/pickle_dataframe`` directory.
        out_dir = os.path.join("data", "pickle_dataframe", tag)
        os.makedirs(out_dir, exist_ok=True)
        for instrument in tqdm(instruments):
            with open(os.path.join(out_dir, f"{instrument}.pkl"), "wb") as dst:
                pickle.dump(select_instrument(df, instrument), dst)


if __name__ == "__main__":
    main()

examples/rl/scripts/data_pipeline.sh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Abort the pipeline as soon as any step fails.
set -e

# Step 1: generate `bin` format data from the raw 1-minute CSV files.
python ../../scripts/dump_bin.py dump_all \
    --csv_path ./data/csv \
    --qlib_dir ./data/bin \
    --include_fields open,close,high,low,vwap,volume \
    --symbol_field_name symbol \
    --date_field_name date \
    --freq 1min

# Step 2: generate pickle format data.
python scripts/gen_pickle_data.py -c scripts/pickle_data_config.yml
# Remove the `stat/` directory if one exists
# (presumably a by-product of the previous step -- verify).
[ ! -e stat/ ] || rm -r stat/
python scripts/collect_pickle_dataframe.py

# Step 3: sample orders for training and for backtest.
python scripts/gen_training_orders.py
python scripts/gen_backtest_orders.py
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import argparse
5+
import os
6+
import pandas as pd
7+
import numpy as np
8+
import pickle
9+
10+
def list_trading_days(df):
    """Return the sorted, unique ``YYYY-MM-DD`` days found in ``df["datetime"]``.

    The original code went through ``.dt.date.astype("datetime64")``, a
    unit-less cast that pandas >= 2.0 rejects. Normalising the timestamps
    directly yields the same day strings. Missing timestamps are dropped so
    that no order can be sampled on an invalid date.
    """
    days = df["datetime"].dt.normalize().dropna().drop_duplicates()
    return sorted(days.dt.strftime("%Y-%m-%d"))


def sample_orders(dates, instrument, num_order):
    """Randomly sample ``num_order`` orders for ``instrument``.

    Draws come from the global NumPy random state in a fixed sequence (dates,
    then amounts, then order types), so results are reproducible for a given
    seed.

    Args:
        dates: candidate day strings; each is used at most once.
        instrument: instrument code written on every order.
        num_order: number of orders to sample.

    Returns:
        DataFrame indexed by ``(date, instrument)`` with columns ``amount``
        (a multiple of 100 between 300 and 1000) and ``order_type`` (0 or 1).

    Raises:
        ValueError: if ``num_order`` exceeds the number of available dates.
    """
    if num_order > len(dates):
        raise ValueError(
            f"Cannot sample {num_order} orders for {instrument}: only {len(dates)} trading day(s) available."
        )
    return pd.DataFrame(
        {
            "date": sorted(np.random.choice(dates, size=num_order, replace=False)),
            "instrument": [instrument] * num_order,
            "amount": np.random.randint(low=3, high=11, size=num_order) * 100.0,
            "order_type": np.random.randint(low=0, high=2, size=num_order),
        }
    ).set_index(["date", "instrument"])


def main():
    """Sample backtest orders for every instrument and write them to CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=20220926)
    parser.add_argument("--num_order", type=int, default=10)
    args = parser.parse_args()

    np.random.seed(args.seed)

    path = os.path.join("data", "pickle", "backtesttest.pkl")  # TODO: rename file
    # Context manager so the pickle file handle is closed after loading.
    with open(path, "rb") as src:
        df = pickle.load(src).reset_index()

    df_list = []
    for instrument in sorted(set(df["instrument"])):
        print(instrument)
        cur_df = df[df["instrument"] == instrument]
        df_list.append(sample_orders(list_trading_days(cur_df), instrument, args.num_order))

    total_df = pd.concat(df_list)
    total_df.to_csv("data/backtest_orders.csv")


if __name__ == "__main__":
    main()
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import yaml
5+
import argparse
6+
import os
7+
from copy import deepcopy
8+
9+
from qlib.contrib.data.highfreq_provider import HighFreqProvider
10+
11+
loader = yaml.FullLoader

if __name__ == "__main__":
    # Command-line interface:
    #   -c/--config: YAML file whose top-level entries are passed to HighFreqProvider.
    #   -d/--dest:   directory prepended to every `path` entry found in the config.
    #   -s/--split:  which per-day / per-stock datasets to generate.
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", type=str, default="config.yml")
    parser.add_argument("-d", "--dest", type=str, default=".")
    parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock")
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed after parsing.
    # NOTE: FullLoader should only be used on trusted config files.
    with open(args.config) as config_file:
        conf = yaml.load(config_file, Loader=loader)

    # Re-root every configured output path under the destination directory.
    for v in conf.values():
        if isinstance(v, dict) and "path" in v:
            v["path"] = os.path.join(args.dest, v["path"])
    provider = HighFreqProvider(**conf)

    # Gen dataframe (the return values are not needed here, only the call itself).
    if "feature_conf" in conf:
        provider._gen_dataframe(deepcopy(provider.feature_conf))
    if "backtest_conf" in conf:
        provider._gen_dataframe(deepcopy(provider.backtest_conf))

    # Turn "<name>.<ext>" into the directory-style path "<name>/" for the split datasets.
    # NOTE(review): both confs are accessed unconditionally from here on, although the
    # dataframe generation above is guarded -- confirm both config keys are mandatory.
    provider.feature_conf["path"] = os.path.splitext(provider.feature_conf["path"])[0] + "/"
    provider.backtest_conf["path"] = os.path.splitext(provider.backtest_conf["path"])[0] + "/"

    # Split by date
    if args.split in ("date", "both"):
        provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_day_dataset(deepcopy(provider.backtest_conf), "backtest")

    # Split by stock
    if args.split in ("stock", "both"):
        provider._gen_stock_dataset(deepcopy(provider.feature_conf), "feature")
        provider._gen_stock_dataset(deepcopy(provider.backtest_conf), "backtest")
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import argparse
5+
import os
6+
import pandas as pd
7+
import numpy as np
8+
import pickle
9+
10+
def list_trading_days(df):
    """Return the sorted, unique ``YYYY-MM-DD`` days found in ``df["datetime"]``.

    The original code went through ``.dt.date.astype("datetime64")``, a
    unit-less cast that pandas >= 2.0 rejects. Normalising the timestamps
    directly yields the same day strings. Missing timestamps are dropped so
    that no order can be sampled on an invalid date.
    """
    days = df["datetime"].dt.normalize().dropna().drop_duplicates()
    return sorted(days.dt.strftime("%Y-%m-%d"))


def sample_orders(dates, instrument, num_order):
    """Randomly sample ``num_order`` training orders for ``instrument``.

    Draws come from the global NumPy random state in a fixed sequence (dates,
    then amounts), so results are reproducible for a given seed.

    Args:
        dates: candidate day strings; each is used at most once.
        instrument: instrument code written on every order.
        num_order: number of orders to sample.

    Returns:
        DataFrame indexed by ``(date, instrument)`` with columns ``amount``
        (a multiple of 100 between 300 and 1000) and ``order_type`` (always 0).

    Raises:
        ValueError: if ``num_order`` exceeds the number of available dates.
    """
    if num_order > len(dates):
        raise ValueError(
            f"Cannot sample {num_order} orders for {instrument}: only {len(dates)} trading day(s) available."
        )
    return pd.DataFrame(
        {
            "date": sorted(np.random.choice(dates, size=num_order, replace=False)),
            "instrument": [instrument] * num_order,
            "amount": np.random.randint(low=3, high=11, size=num_order) * 100.0,
            "order_type": [0] * num_order,
        }
    ).set_index(["date", "instrument"])


def main():
    """Sample train/valid/test orders and pickle them under ``data/training_order_split``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=20220926)
    parser.add_argument("--stock", type=str, default="AAPL")
    parser.add_argument("--train_size", type=int, default=10)
    parser.add_argument("--valid_size", type=int, default=2)
    parser.add_argument("--test_size", type=int, default=2)
    args = parser.parse_args()

    np.random.seed(args.seed)

    sizes = (args.train_size, args.valid_size, args.test_size)
    for group, n in zip(("train", "valid", "test"), sizes):
        path = os.path.join("data", "pickle", f"backtest{group}.pkl")
        # Context managers make sure every file handle is closed promptly.
        with open(path, "rb") as src:
            df = pickle.load(src).reset_index()

        data_df = sample_orders(list_trading_days(df), args.stock, n)

        # makedirs also creates the parent ``data/training_order_split`` directory.
        out_dir = os.path.join("data", "training_order_split", group)
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, f"{args.stock}.pkl"), "wb") as dst:
            pickle.dump(data_df, dst)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)