Skip to content

Commit 9ad6b55

Browse files
committed
refactor online serving rolling api
1 parent dafef0a commit 9ad6b55

File tree

6 files changed

+82
-55
lines changed

6 files changed

+82
-55
lines changed

examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,4 @@ task:
7878
- class: PortAnaRecord
7979
module_path: qlib.workflow.record_temp
8080
kwargs:
81-
config: *port_analysis_config
81+
config: *port_analysis_config
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
xgboost

qlib/utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,9 +570,11 @@ def get_pre_trading_date(trading_date, future=False):
570570

571571

572572
def transform_end_date(end_date=None, freq="day"):
573-
"""get previous trading date
573+
"""Handle an end date given in various formats.
574+
574575
If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned.
575576
Otherwise, returns the end_date
577+
576578
----------
577579
end_date: str
578580
end trading date

qlib/workflow/online/strategy.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from qlib.data.data import D
1111
from qlib.log import get_module_logger
1212
from qlib.model.ens.group import RollingGroup
13+
from qlib.utils import transform_end_date
1314
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
1415
from qlib.workflow.recorder import Recorder
1516
from qlib.workflow.task.collect import Collector, RecorderCollector
@@ -118,6 +119,7 @@ def __init__(
118119
task_template = [task_template]
119120
self.task_template = task_template
120121
self.rg = rolling_gen
122+
assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen"
121123
self.tool = OnlineToolR(self.exp_name)
122124
self.ta = TimeAdjuster()
123125

@@ -174,28 +176,20 @@ def prepare_tasks(self, cur_time) -> List[dict]:
174176
Returns:
175177
List[dict]: a list of new tasks.
176178
"""
179+
# TODO: filtering recorders by their latest test segments is not necessary
177180
latest_records, max_test = self._list_latest(self.tool.online_models())
178181
if max_test is None:
179182
self.logger.warn(f"No latest online recorders, no new tasks.")
180183
return []
181-
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
184+
calendar_latest = transform_end_date(cur_time)
182185
self.logger.info(
183186
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
184187
)
185-
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
186-
old_tasks = []
187-
tasks_tmp = []
188-
for rec in latest_records:
189-
task = rec.load_object("task")
190-
old_tasks.append(deepcopy(task))
191-
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
192-
# modify the test segment to generate new tasks
193-
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
194-
tasks_tmp.append(task)
195-
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
196-
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
197-
return new_tasks
198-
return []
188+
res = []
189+
for rec in latest_records:
190+
task = rec.load_object("task")
191+
res.extend(self.rg.gen_following_tasks(task, calendar_latest))
192+
return res
199193

200194
def _list_latest(self, rec_list: List[Recorder]):
201195
"""

qlib/workflow/online/update.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"
105105
if to_date == None:
106106
to_date = D.calendar(freq=freq)[-1]
107107
self.to_date = pd.Timestamp(to_date)
108+
# FIXME: it will raise an error when running the routine with a delay trainer
109+
# should we use another prediction updater for the delay trainer?
108110
self.old_pred = record.load_object("pred.pkl")
109111
self.last_end = self.old_pred.index.get_level_values("datetime").max()
110112

qlib/workflow/task/gen.py

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
import abc
77
import copy
8+
import pandas as pd
89
from typing import List, Union, Callable
910

1011
from qlib.utils import transform_end_date
@@ -139,6 +140,53 @@ def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Unio
139140
self.test_key = "test"
140141
self.train_key = "train"
141142

143+
def _update_task_segs(self, task, segs):
    """
    Install `segs` as the dataset segments of `task` (in place).

    Parameters
    ----------
    task : dict
        Qlib task format; must contain ``task["dataset"]["kwargs"]``.
    segs : dict
        the new segments; a deep copy is stored so `task` does not share state with `segs`
    """
    # copy first, then assign, so the task never aliases the caller's segments
    fresh_segs = copy.deepcopy(segs)
    task["dataset"]["kwargs"]["segments"] = fresh_segs

    # optional user hook that may further adjust the task after the segments changed
    extra_mod = self.ds_extra_mod_func
    if extra_mod is None:
        return
    extra_mod(task, self)
148+
149+
def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]:
    """
    Generate the rolling tasks that follow `task`, until `test_end`.

    Each following task is obtained by shifting every segment of the previous
    one by `self.step`. The train segment is expanded when `self.rtype` is
    `ROLL_EX`; every other segment slides, so its size does not change.

    NOTE: this is a generator function. Tasks are produced lazily, and every
    produced task is an independent deep copy, so collecting them (e.g. with
    `list.extend`) gives distinct tasks with distinct segments.

    Parameters
    ----------
    task : dict
        Qlib task format; it is not modified
    test_end : pd.Timestamp
        the latest rolling task includes `test_end`

    Returns
    -------
    List[dict]:
        the following tasks of `task` (`task` itself is excluded)
    """
    # Private snapshot: later changes made by the caller to `task` cannot leak
    # into tasks that are generated lazily afterwards.
    base_task = copy.deepcopy(task)
    prev_seg = base_task["dataset"]["kwargs"]["segments"]
    while True:
        segments = {}
        try:
            for k, seg in prev_seg.items():
                # decide how to shift
                # expanding only for train data, the segments size of test data and valid data won't change
                if k == self.train_key and self.rtype == self.ROLL_EX:
                    rtype = self.ta.SHIFT_EX
                else:
                    rtype = self.ta.SHIFT_SD
                # shift the segments data
                segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
            if segments[self.test_key][0] > test_end:
                # the shifted test segment starts after `test_end`: stop rolling
                break
        except KeyError:
            # We reach the end of tasks
            # No more rolling
            break

        prev_seg = segments
        # Build a fresh task for every step. Re-using one dict would make every
        # yielded item the same object, so all collected tasks would end up
        # carrying the segments of the last step.
        t = copy.deepcopy(base_task)
        self._update_task_segs(t, segments)
        yield t
189+
142190
def generate(self, task: dict) -> List[dict]:
143191
"""
144192
Converting the task into a rolling task.
@@ -191,43 +239,23 @@ def generate(self, task: dict) -> List[dict]:
191239
"""
192240
res = []
193241

194-
prev_seg = None
195-
test_end = None
196-
while True:
197-
t = copy.deepcopy(task)
198-
199-
# calculate segments
200-
if prev_seg is None:
201-
# First rolling
202-
# 1) prepare the end point
203-
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
204-
test_end = transform_end_date(segments[self.test_key][1])
205-
# 2) and init test segments
206-
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
207-
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
208-
else:
209-
segments = {}
210-
try:
211-
for k, seg in prev_seg.items():
212-
# decide how to shift
213-
# expanding only for train data, the segments size of test data and valid data won't change
214-
if k == self.train_key and self.rtype == self.ROLL_EX:
215-
rtype = self.ta.SHIFT_EX
216-
else:
217-
rtype = self.ta.SHIFT_SD
218-
# shift the segments data
219-
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
220-
if segments[self.test_key][0] > test_end:
221-
break
222-
except KeyError:
223-
# We reach the end of tasks
224-
# No more rolling
225-
break
242+
t = copy.deepcopy(task)
226243

227-
# update segments of this task
228-
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
229-
prev_seg = segments
230-
if self.ds_extra_mod_func is not None:
231-
self.ds_extra_mod_func(t, self)
232-
res.append(t)
244+
# calculate segments
245+
246+
# First rolling
247+
# 1) prepare the end point
248+
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
249+
test_end = transform_end_date(segments[self.test_key][1])
250+
# 2) and init test segments
251+
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
252+
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
253+
254+
# update segments of this task
255+
self._update_task_segs(t, segments)
256+
257+
res.append(t)
258+
259+
# Update the following rolling
260+
res.extend(self.gen_following_tasks(t, test_end))
233261
return res

0 commit comments

Comments
 (0)