Skip to content

Commit ccc3f96

Browse files
committed
Pass mypy
1 parent 83d8f00 commit ccc3f96

File tree

7 files changed

+39
-34
lines changed

7 files changed

+39
-34
lines changed

qlib/rl/from_neutrader/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
@dataclass
1010
class ExchangeConfig:
1111
limit_threshold: Union[float, Tuple[str, str]]
12-
deal_price: Union[str, Tuple[str, str]]
12+
deal_price: Union[str, Tuple[str]]
1313
volume_threshold: dict
1414
open_cost: float = 0.0005
1515
close_cost: float = 0.0015

qlib/rl/from_neutrader/feature.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
class LRUCache:
1616
def __init__(self, pool_size: int = 200):
1717
self.pool_size = pool_size
18-
self.contents = dict()
19-
self.keys = collections.deque()
18+
self.contents: dict = {}
19+
self.keys: collections.deque = collections.deque()
2020

2121
def put(self, key, item):
2222
if self.has(key):
@@ -52,7 +52,7 @@ def __init__(
5252
self.feature_cache = LRUCache()
5353
self.backtest_cache = LRUCache()
5454

55-
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False):
55+
def get(self, stock_id: str, date: pd.Timestamp, backtest: bool = False) -> pd.DataFrame:
5656
start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace(hour=23, minute=59, second=59)
5757

5858
if backtest:

qlib/rl/order_execution/interpreter.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,11 @@ def interpret(self, state: SAOEState) -> CurrentStateObs:
165165
assert self.env is not None
166166
assert self.env.status["cur_step"] <= self.max_step
167167
obs = CurrentStateObs(
168-
**{
169-
"acquiring": state.order.direction == state.order.BUY,
170-
"cur_step": self.env.status["cur_step"],
171-
"num_step": self.max_step,
172-
"target": state.order.amount,
173-
"position": state.position,
174-
}
168+
acquiring=state.order.direction == state.order.BUY,
169+
cur_step=self.env.status["cur_step"],
170+
num_step=self.max_step,
171+
target=state.order.amount,
172+
position=state.position,
175173
)
176174
return obs
177175

qlib/rl/order_execution/simulator_qlib.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""Placeholder for qlib-based simulator."""
55
from __future__ import annotations
66

7-
from typing import Callable, Generator, List, Optional, Tuple, cast
7+
from typing import Any, Callable, Generator, List, Optional, Tuple, cast
88

99
import numpy as np
1010
import pandas as pd
@@ -36,7 +36,7 @@ def __init__(self) -> None:
3636
self.execute_order: Optional[Order] = None
3737
self.execute_result: List[Tuple[Order, float, float, float]] = []
3838

39-
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
39+
def generate_trade_decision(self, execute_result: list = None) -> Generator[Any, Any, BaseTradeDecision]:
4040
exec_vol = yield self
4141

4242
oh = self.trade_exchange.get_order_helper()
@@ -52,7 +52,7 @@ def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) ->
5252
def post_exe_step(self, execute_result: list) -> None:
5353
self.execute_result = execute_result
5454

55-
def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs) -> None:
55+
def reset(self, outer_trade_decision: TradeDecisionWO = None, **kwargs: Any) -> None:
5656
super().reset(outer_trade_decision=outer_trade_decision, **kwargs)
5757
if outer_trade_decision is not None:
5858
order_list = outer_trade_decision.order_list
@@ -83,7 +83,7 @@ def generate_trade_decision(self, execute_result: list = None) -> TradeDecisionW
8383
oh.create(
8484
code=self._instrument,
8585
amount=self._order.amount,
86-
direction=Order.parse_dir(self._order.direction),
86+
direction=self._order.direction,
8787
),
8888
]
8989
return TradeDecisionWO(order_list, self, self._trade_range)
@@ -102,7 +102,7 @@ def __init__(self, order: Order, tick_index: pd.DatetimeIndex, twap_price: float
102102
# NOTE: can empty dataframe contain index?
103103
self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime")
104104
self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime")
105-
self.metrics = None
105+
self.metrics: Optional[SAOEMetrics] = None
106106

107107
def update(
108108
self,
@@ -116,6 +116,8 @@ def update(
116116
exec_vol = np.array([e[0].deal_amount for e in execute_result])
117117
num_step = len(execute_result)
118118

119+
assert execute_order is not None
120+
119121
if num_step == 0:
120122
market_volume = np.array([])
121123
market_price = np.array([])
@@ -251,7 +253,7 @@ def __init__(
251253
exchange_config: ExchangeConfig,
252254
) -> None:
253255
super().__init__(
254-
initial=None, # TODO
256+
initial=order, # TODO: confirm this logic
255257
)
256258

257259
assert order.start_time.date() == order.end_time.date()
@@ -330,6 +332,8 @@ def reset(self, order: Order) -> None:
330332
)
331333

332334
def _iter_strategy(self, action: float = None) -> DecomposedStrategy:
335+
assert self._collect_data_loop is not None
336+
333337
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
334338
while not isinstance(strategy, DecomposedStrategy):
335339
strategy = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action)
@@ -344,6 +348,7 @@ def step(self, action: float) -> None:
344348
except StopIteration:
345349
self._done = True
346350

351+
assert self._executor is not None
347352
_, all_indicators = get_portfolio_and_indicator(self._executor)
348353

349354
self._maintainer.update(

qlib/rl/reward.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def reward(self, simulator_state: SimulatorState) -> float:
3131
raise NotImplementedError("Implement reward calculation recipe in `reward()`.")
3232

3333
def log(self, name: str, value: Any) -> None:
34+
assert self.env is not None
3435
self.env.logger.add_scalar(name, value)
3536

3637

qlib/rl/utils/data_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __enter__(self) -> DataQueue:
8484
self.activate()
8585
return self
8686

87-
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
87+
def __exit__(self, exc_type, exc_val, exc_tb):
8888
self.cleanup()
8989

9090
def cleanup(self) -> None:

qlib/rl/utils/finite_env.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ def fill_invalid(obj: int | float | bool | np.ndarray | dict | list | tuple) ->
5757

5858
def is_invalid(arr: int | float | bool | np.ndarray | dict | list | tuple) -> bool:
5959
if hasattr(arr, "dtype"):
60-
if np.issubdtype(arr.dtype, np.floating):
60+
dtype = getattr(arr, "dtype")
61+
if np.issubdtype(dtype, np.floating):
6162
return np.isnan(arr).all()
62-
return (np.iinfo(arr.dtype).max == arr).all()
63+
return (np.iinfo(dtype).max == arr).all()
6364
if isinstance(arr, dict):
6465
return all(is_invalid(o) for o in arr.values())
6566
if isinstance(arr, (list, tuple)):
@@ -209,7 +210,7 @@ def collector_guard(self) -> Generator[FiniteVectorEnv, None, None]:
209210

210211
def reset(
211212
self,
212-
id: int | List[int] | np.ndarray = None,
213+
id: int | List[int] | np.ndarray | None = None,
213214
) -> np.ndarray:
214215
assert not self._zombie
215216

@@ -222,23 +223,23 @@ def reset(
222223
RuntimeWarning,
223224
)
224225

225-
id = self._wrap_id(id)
226+
wrapped_id = self._wrap_id(id)
226227
self._reset_alive_envs()
227228

228229
# ask super to reset alive envs and remap to current index
229-
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
230-
obs = [None] * len(id)
231-
id2idx = {i: k for k, i in enumerate(id)}
230+
request_id = [i for i in wrapped_id if i in self._alive_env_ids]
231+
obs = [None] * len(wrapped_id)
232+
id2idx = {i: k for k, i in enumerate(wrapped_id)}
232233
if request_id:
233234
for i, o in zip(request_id, super().reset(request_id)):
234235
obs[id2idx[i]] = self._postproc_env_obs(o)
235236

236-
for i, o in zip(id, obs):
237+
for i, o in zip(wrapped_id, obs):
237238
if o is None and i in self._alive_env_ids:
238239
self._alive_env_ids.remove(i)
239240

240241
# logging
241-
for i, o in zip(id, obs):
242+
for i, o in zip(wrapped_id, obs):
242243
if i in self._alive_env_ids:
243244
for logger in self._logger:
244245
logger.on_env_reset(i, obs)
@@ -251,7 +252,7 @@ def reset(
251252
obs[i] = self._get_default_obs()
252253

253254
if not self._alive_env_ids:
254-
# comment this line so that the env becomes indisposable
255+
# comment this line so that the env becomes indispensable
255256
# self.reset()
256257
self._zombie = True
257258
raise StopIteration
@@ -261,13 +262,13 @@ def reset(
261262
def step(
262263
self,
263264
action: np.ndarray,
264-
id: int | List[int] | np.ndarray = None,
265+
id: int | List[int] | np.ndarray | None = None,
265266
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
266267
assert not self._zombie
267-
id = self._wrap_id(id)
268-
id2idx = {i: k for k, i in enumerate(id)}
269-
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
270-
result = [[None, None, False, None] for _ in range(len(id))]
268+
wrapped_id = self._wrap_id(id)
269+
id2idx = {i: k for k, i in enumerate(wrapped_id)}
270+
request_id = list(filter(lambda i: i in self._alive_env_ids, wrapped_id))
271+
result = [[None, None, False, None] for _ in range(len(wrapped_id))]
271272

272273
# ask super to step alive envs and remap to current index
273274
if request_id:
@@ -277,7 +278,7 @@ def step(
277278
result[id2idx[i]][0] = self._postproc_env_obs(result[id2idx[i]][0])
278279

279280
# logging
280-
for i, r in zip(id, result):
281+
for i, r in zip(wrapped_id, result):
281282
if i in self._alive_env_ids:
282283
for logger in self._logger:
283284
logger.on_env_step(i, *r)

0 commit comments

Comments
 (0)