|
5 | 5 | """ |
6 | 6 | import abc |
7 | 7 | import copy |
| 8 | +import pandas as pd |
8 | 9 | from typing import List, Union, Callable |
9 | 10 |
|
10 | 11 | from qlib.utils import transform_end_date |
@@ -139,6 +140,53 @@ def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Unio |
139 | 140 | self.test_key = "test" |
140 | 141 | self.train_key = "train" |
141 | 142 |
|
| 143 | + def _update_task_segs(self, task, segs): |
| 144 | + # update segments of this task |
| 145 | + task["dataset"]["kwargs"]["segments"] = copy.deepcopy(segs) |
| 146 | + if self.ds_extra_mod_func is not None: |
| 147 | + self.ds_extra_mod_func(task, self) |
| 148 | + |
| 149 | + def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]: |
| 150 | + """ |
| 151 | + generating following rolling tasks for `task` until test_end |
| 152 | +
|
| 153 | + Parameters |
| 154 | + ---------- |
| 155 | + task : dict |
| 156 | + Qlib task format |
| 157 | + test_end : pd.Timestamp |
| 158 | + the latest rolling task includes `test_end` |
| 159 | +
|
| 160 | + Returns |
| 161 | + ------- |
| 162 | + List[dict]: |
| 163 | + the following tasks of `task`(`task` itself is excluded) |
| 164 | + """ |
| 165 | + t = copy.deepcopy(task) |
| 166 | + prev_seg = t["dataset"]["kwargs"]["segments"] |
| 167 | + while True: |
| 168 | + segments = {} |
| 169 | + try: |
| 170 | + for k, seg in prev_seg.items(): |
| 171 | + # decide how to shift |
| 172 | + # expanding only for train data, the segments size of test data and valid data won't change |
| 173 | + if k == self.train_key and self.rtype == self.ROLL_EX: |
| 174 | + rtype = self.ta.SHIFT_EX |
| 175 | + else: |
| 176 | + rtype = self.ta.SHIFT_SD |
| 177 | + # shift the segments data |
| 178 | + segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype) |
| 179 | + if segments[self.test_key][0] > test_end: |
| 180 | + break |
| 181 | + except KeyError: |
| 182 | + # We reach the end of tasks |
| 183 | + # No more rolling |
| 184 | + break |
| 185 | + |
| 186 | + prev_seg = segments |
| 187 | + self._update_task_segs(t, segments) |
| 188 | + yield t |
| 189 | + |
142 | 190 | def generate(self, task: dict) -> List[dict]: |
143 | 191 | """ |
144 | 192 | Converting the task into a rolling task. |
@@ -191,43 +239,23 @@ def generate(self, task: dict) -> List[dict]: |
191 | 239 | """ |
192 | 240 | res = [] |
193 | 241 |
|
194 | | - prev_seg = None |
195 | | - test_end = None |
196 | | - while True: |
197 | | - t = copy.deepcopy(task) |
198 | | - |
199 | | - # calculate segments |
200 | | - if prev_seg is None: |
201 | | - # First rolling |
202 | | - # 1) prepare the end point |
203 | | - segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) |
204 | | - test_end = transform_end_date(segments[self.test_key][1]) |
205 | | - # 2) and init test segments |
206 | | - test_start_idx = self.ta.align_idx(segments[self.test_key][0]) |
207 | | - segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) |
208 | | - else: |
209 | | - segments = {} |
210 | | - try: |
211 | | - for k, seg in prev_seg.items(): |
212 | | - # decide how to shift |
213 | | - # expanding only for train data, the segments size of test data and valid data won't change |
214 | | - if k == self.train_key and self.rtype == self.ROLL_EX: |
215 | | - rtype = self.ta.SHIFT_EX |
216 | | - else: |
217 | | - rtype = self.ta.SHIFT_SD |
218 | | - # shift the segments data |
219 | | - segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype) |
220 | | - if segments[self.test_key][0] > test_end: |
221 | | - break |
222 | | - except KeyError: |
223 | | - # We reach the end of tasks |
224 | | - # No more rolling |
225 | | - break |
| 242 | + t = copy.deepcopy(task) |
226 | 243 |
|
227 | | - # update segments of this task |
228 | | - t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) |
229 | | - prev_seg = segments |
230 | | - if self.ds_extra_mod_func is not None: |
231 | | - self.ds_extra_mod_func(t, self) |
232 | | - res.append(t) |
| 244 | + # calculate segments |
| 245 | + |
| 246 | + # First rolling |
| 247 | + # 1) prepare the end point |
| 248 | + segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) |
| 249 | + test_end = transform_end_date(segments[self.test_key][1]) |
| 250 | + # 2) and init test segments |
| 251 | + test_start_idx = self.ta.align_idx(segments[self.test_key][0]) |
| 252 | + segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) |
| 253 | + |
| 254 | + # update segments of this task |
| 255 | + self._update_task_segs(t, segments) |
| 256 | + |
| 257 | + res.append(t) |
| 258 | + |
| 259 | + # Update the following rolling |
| 260 | + res.extend(self.gen_following_tasks(t, test_end)) |
233 | 261 | return res |
0 commit comments