diff --git a/2.0.0/404.html b/2.0.0/404.html
new file mode 100644
index 00000000..1572dd35
--- /dev/null
+++ b/2.0.0/404.html
@@ -0,0 +1,660 @@
Board

omicron/models/board.py
class Board:
+ server_ip: str
+ server_port: int
+ measurement = "board_bars_1d"
+
+ @classmethod
+ def init(cls, ip: str, port: int = 3180):
+ cls.server_ip = ip
+ cls.server_port = port
+
+ @classmethod
+ async def _rpc_call(cls, url: str, param: str):
+ _url = f"http://{cls.server_ip}:{cls.server_port}/api/board/{url}"
+
+ async with httpx.AsyncClient() as client:
+ r = await client.post(_url, json=param, timeout=10)
+ if r.status_code != 200:
+ logger.error(
+ f"failed to post RPC call, {_url}: {param}, response: {r.content.decode()}"
+ )
+ return {"rc": r.status_code}
+
+ rsp = json.loads(r.content)
+ return {"rc": 200, "data": rsp}
+
+ @classmethod
+ async def board_list(cls, _btype: BoardType = BoardType.CONCEPT) -> List[List]:
+ """获取板块列表
+
+ Args:
+ _btype: 板块类别,可选值`BoardType.CONCEPT`和`BoardType.INDUSTRY`.
+
+ Returns:
+ 板块列表。每一个子元素仍为一个列表,由板块代码(str), 板块名称(str)和成员数组成。示例:
+ ```
+ [
+ ['881101', '种植业与林业', 24],
+ ['881102', '养殖业', 27],
+ ['881103', '农产品加工', 41],
+ ['881104', '农业服务', 16],
+ ]
+ ```
+ """
+ rsp = await cls._rpc_call("board_list", {"board_type": _btype.value})
+ if rsp["rc"] != 200:
+ return {"status": 500, "msg": "httpx RPC call failed"}
+
+ return rsp["data"]
+
+ @classmethod
+ async def fuzzy_match_board_name(
+ cls, pattern: str, _btype: BoardType = BoardType.CONCEPT
+ ) -> dict:
+ """模糊查询板块代码的名字
+
+ Examples:
+ ```python
+ await Board.fuzzy_match_board_name("汽车", BoardType.INDUSTRY)
+
+ # returns:
+ [
+ '881125 汽车整车',
+ '881126 汽车零部件',
+ '881127 非汽车交运',
+ '881128 汽车服务',
+ '884107 汽车服务Ⅲ',
+ '884194 汽车零部件Ⅲ'
+ ]
+ ```
+        Args:
+            pattern: the pattern to match against
+            _btype: the board category to search
+
+        Returns:
+            A dict containing the keys: code (board code), name (board name) and stocks (member count)
+ """
+ if not pattern:
+ return []
+
+ rsp = await cls._rpc_call(
+ "fuzzy_match_name", {"board_type": _btype.value, "pattern": pattern}
+ )
+ if rsp["rc"] != 200:
+ return {"status": 500, "msg": "httpx RPC call failed"}
+
+ return rsp["data"]
+
+ @classmethod
+ async def board_info_by_id(cls, board_id: str, full_mode: bool = False) -> dict:
+ """通过板块代码查询板块信息(名字,成员数目或清单)
+
+ Examples:
+ ```python
+        board_code = '881128'  # 汽车服务 (auto services); change as needed
+        board_info = await Board.board_info_by_id(board_code)
+        print(board_info)  # a dict
+
+ # returns
+ {'code': '881128', 'name': '汽车服务', 'stocks': 14}
+ ```
+
+ Returns:
+ {'code': '301505', 'name': '医疗器械概念', 'stocks': 242}
+ or
+ {'code': '301505', 'name': '医疗器械概念', 'stocks': [['300916', '朗特智能'], ['300760', '迈瑞医疗']]}
+ """
+
+ if not board_id:
+ return {}
+ if board_id[0] == "3":
+ _btype = BoardType.CONCEPT
+ else:
+ _btype = BoardType.INDUSTRY
+
+ _mode = 0
+        if full_mode:  # convert the bool flag to an int
+ _mode = 1
+
+ rsp = await cls._rpc_call(
+ "info",
+ {"board_type": _btype.value, "board_id": board_id, "fullmode": _mode},
+ )
+ if rsp["rc"] != 200:
+ return {"status": 500, "msg": "httpx RPC call failed"}
+
+ return rsp["data"]
+
+ @classmethod
+ async def board_info_by_security(
+ cls, security: str, _btype: BoardType = BoardType.CONCEPT
+ ) -> List[dict]:
+ """获取股票所在板块信息:名称,代码
+
+ Examples:
+ ```python
+        stock_code = '002236'  # 大华股份 (Dahua Technology); bare code, no exchange suffix
+ stock_in_board = await Board.board_info_by_security(stock_code, _btype=BoardType.CONCEPT)
+ print(stock_in_board)
+
+ # returns:
+ [
+ {'code': '301715', 'name': '证金持股', 'stocks': 208},
+ {'code': '308870', 'name': '数字经济', 'stocks': 195},
+ {'code': '308642', 'name': '数据中心', 'stocks': 188},
+ ...,
+ {'code': '300008', 'name': '新能源汽车', 'stocks': 603}
+ ]
+ ```
+
+ Returns:
+ [{'code': '301505', 'name': '医疗器械概念'}]
+ """
+
+ if not security:
+ return []
+
+ rsp = await cls._rpc_call(
+ "info_by_sec", {"board_type": _btype.value, "security": security}
+ )
+ if rsp["rc"] != 200:
+ return {"status": 500, "msg": "httpx RPC call failed"}
+
+ return rsp["data"]
+
+ @classmethod
+ async def board_filter_members(
+ cls,
+ included: List[str],
+ excluded: List[str] = [],
+ _btype: BoardType = BoardType.CONCEPT,
+ ) -> List:
+ """根据板块名筛选股票,参数为include, exclude
+
+ Fixme:
+ this function doesn't work
+ Raise status 500
+
+ Returns:
+ [['300181', '佐力药业'], ['600056', '中国医药']]
+ """
+ if not included:
+ return []
+ if excluded is None:
+ excluded = []
+
+ rsp = await cls._rpc_call(
+ "board_filter_members",
+ {
+ "board_type": _btype.value,
+ "include_boards": included,
+ "exclude_boards": excluded,
+ },
+ )
+ if rsp["rc"] != 200:
+ return {"status": 500, "msg": "httpx RPC call failed"}
+
+ return rsp["data"]
+
+ @classmethod
+ async def new_concept_boards(cls, days: int = 10):
+ raise NotImplementedError("not ready")
+
+ @classmethod
+    async def latest_concept_boards(cls, n: int = 3):
+ raise NotImplementedError("not ready")
+
+ @classmethod
+    async def new_concept_members(cls, days: int = 10, prot: int = None):
+ raise NotImplementedError("not ready")
+
+ @classmethod
+ async def board_filter(
+ cls, industry=None, with_concepts: Optional[List[str]] = None, without=[]
+ ):
+ raise NotImplementedError("not ready")
+
+ @classmethod
+ async def save_bars(cls, bars):
+ client = get_influx_client()
+
+ logger.info(
+ "persisting bars to influxdb: %s, %d secs", cls.measurement, len(bars)
+ )
+ await client.save(bars, cls.measurement, tag_keys=["code"], time_key="frame")
+ return True
+
+ @classmethod
+ async def get_last_date_of_bars(cls, code: str):
+        # industry boards look back one year; concept boards only cover the current year
+ code = f"{code}.THS"
+
+ client = get_influx_client()
+
+ now = datetime.datetime.now()
+ dt_end = tf.day_shift(now, 0)
+        # 250 + 60: yields 60 points of MA250; the default candlestick chart shows 120 nodes
+ dt_start = tf.day_shift(now, -310)
+
+ flux = (
+ Flux()
+ .measurement(cls.measurement)
+ .range(dt_start, dt_end)
+ .bucket(client._bucket)
+ .tags({"code": code})
+ )
+
+ data = await client.query(flux)
+        if len(data) == 2:  # an empty response contains only "\r\n"
+ return dt_start
+ ds = DataframeDeserializer(
+ sort_values="_time", usecols=["_time"], time_col="_time", engine="c"
+ )
+ bars = ds(data)
+ secs = bars.to_records(index=False).astype("datetime64[s]")
+
+ _dt = secs[-1].item()
+ return _dt.date()
+
+ @classmethod
+ async def get_bars_in_range(
+ cls, code: str, start: Frame, end: Frame = None
+ ) -> BarsArray:
+ """从持久化数据库中获取介于[`start`, `end`]间的行情记录
+
+ Examples:
+ ```python
+        start = datetime.date(2022, 9, 1)  # start time; change as needed
+        end = datetime.date(2023, 3, 1)  # end time; change as needed
+        board_code = '881128'  # 汽车服务 (auto services); change as needed
+        bars = await Board.get_bars_in_range(board_code, start, end)
+        bars[-3:]  # show the last 3 records
+
+ # prints:
+ rec.array([
+ ('2023-02-27T00:00:00', 1117.748, 1124.364, 1108.741, 1109.525, 1.77208600e+08, 1.13933095e+09, 1.),
+ ('2023-02-28T00:00:00', 1112.246, 1119.568, 1109.827, 1113.43 , 1.32828124e+08, 6.65160380e+08, 1.),
+ ('2023-03-01T00:00:00', 1122.233, 1123.493, 1116.62 , 1123.274, 7.21718910e+07, 3.71172850e+08, 1.)
+ ],
+ dtype=[('frame', '<M8[s]'), ('open', '<f4'), ('high', '<f4'), ('low', '<f4'), ('close', '<f4'), ('volume', '<f8'), ('amount', '<f8'), ('factor', '<f4')])
+ ```
+        Args:
+            code: board code (concept or industry)
+            start: start time
+            end: end time; defaults to the current time if not given
+
+        Returns:
+            A one-dimensional numpy array with dtype `coretypes.bars_dtype`.
+ """
+ end = end or datetime.datetime.now()
+ code = f"{code}.THS"
+
+ keep_cols = ["_time"] + list(bars_cols[1:])
+
+ flux = (
+ Flux()
+ .bucket(cfg.influxdb.bucket_name)
+ .range(start, end)
+ .measurement(cls.measurement)
+ .fields(keep_cols)
+ .tags({"code": code})
+ )
+
+ serializer = DataframeDeserializer(
+ encoding="utf-8",
+ names=[
+ "_",
+ "table",
+ "result",
+ "frame",
+ "code",
+ "amount",
+ "close",
+ "factor",
+ "high",
+ "low",
+ "open",
+ "volume",
+ ],
+ engine="c",
+ skiprows=0,
+ header=0,
+ usecols=bars_cols,
+ parse_dates=["frame"],
+ )
+
+ client = get_influx_client()
+ result = await client.query(flux, serializer)
+ return result.to_records(index=False).astype(bars_dtype)
+
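The RPC-backed methods above all share one pattern: `init()` stores the server address on the class, and each query POSTs JSON to `/api/board/<endpoint>` via `_rpc_call`. Below is a minimal usage sketch; the server address is a placeholder and assumes a boards service is reachable there.

```python
import asyncio

from omicron.models.board import Board, BoardType


async def main():
    Board.init("192.168.1.100", port=3180)  # placeholder host/port

    # each element is [board code, board name, member count]
    concepts = await Board.board_list(BoardType.CONCEPT)
    print(concepts[:3])


asyncio.run(main())
```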
board_filter_members(included, excluded=[], _btype=<BoardType.CONCEPT: 'concept'>)

    async classmethod

Filter member stocks by board names, using the `included` and `excluded` lists.

FIXME: this function doesn't work; the call currently fails with status 500.

Returns:
    List, e.g. [['300181', '佐力药业'], ['600056', '中国医药']]
board_info_by_id(board_id, full_mode=False)

    async classmethod

Look up board information (name, and member count or member list) by board code.

Example:
    board_info = await Board.board_info_by_id('881128')
    # {'code': '881128', 'name': '汽车服务', 'stocks': 14}

Returns:
    {'code': '301505', 'name': '医疗器械概念', 'stocks': 242}
    or, with full_mode=True,
    {'code': '301505', 'name': '医疗器械概念', 'stocks': [['300916', '朗特智能'], ['300760', '迈瑞医疗']]}
board_info_by_security(security, _btype=<BoardType.CONCEPT: 'concept'>)

    async classmethod

Get the boards a security belongs to: names and codes.

Example:
    stock_in_board = await Board.board_info_by_security('002236', _btype=BoardType.CONCEPT)
    # [{'code': '301715', 'name': '证金持股', 'stocks': 208}, ...]

Returns:
    [{'code': '301505', 'name': '医疗器械概念'}]
board_list(_btype=<BoardType.CONCEPT: 'concept'>)

    async classmethod

Fetch the list of boards.

Parameters:
    _btype (BoardType): the board category; one of `BoardType.CONCEPT` or `BoardType.INDUSTRY`. Defaults to <BoardType.CONCEPT: 'concept'>.

Returns:
    List[List]: the list of boards; each element is itself a list of board code (str), board name (str) and member count.
fuzzy_match_board_name(pattern, _btype=<BoardType.CONCEPT: 'concept'>)

    async classmethod

Fuzzy-match board names against a pattern.

Example:
    await Board.fuzzy_match_board_name("汽车", BoardType.INDUSTRY)
    # ['881125 汽车整车', '881126 汽车零部件', ...]

Parameters:
    pattern (str): the pattern to match against. Required.
    _btype (BoardType): the board category to search. Defaults to <BoardType.CONCEPT: 'concept'>.

Returns:
    A dict containing the keys: code (board code), name (board name) and stocks (member count).
get_bars_in_range(code, start, end=None)

    async classmethod

Fetch bars between [`start`, `end`] from the persistence database.

Parameters:
    code (str): board code (concept or industry). Required.
    start (Union[datetime.date, datetime.datetime]): start time. Required.
    end (Union[datetime.date, datetime.datetime]): end time; defaults to the current time. Defaults to None.

Returns:
    A one-dimensional numpy array with dtype `coretypes.bars_dtype`, i.e. [('frame', '<M8[s]'), ('open', '<f4'), ('high', '<f4'), ('low', '<f4'), ('close', '<f4'), ('volume', '<f8'), ('amount', '<f8'), ('factor', '<f4')].
BoardType (Enum)

An enumeration.

omicron/models/board.py
class BoardType(Enum):
+ INDUSTRY = "industry"
+ CONCEPT = "concept"

Flux

Helper functions for building flux query expression.

omicron/dal/influx/flux.py
class Flux(object):
+ """Helper functions for building flux query expression"""
+
+ EPOCH_START = datetime.datetime(1970, 1, 1, 0, 0, 0)
+
+ def __init__(self, auto_pivot=True, no_sys_cols=True):
+ """初始化Flux对象
+
+ Args:
+ auto_pivot : 是否自动将查询列字段组装成行. Defaults to True.
+ no_sys_cols: 是否自动将系统字段删除. Defaults to True.请参考[drop_sys_cols][omicron.dal.influx.flux.Flux.drop_sys_cols]
+ """
+ self._cols = None
+ self.expressions = defaultdict(list)
+ self._auto_pivot = auto_pivot
+ self._last_n = None
+ self.no_sys_cols = no_sys_cols
+
+ def __str__(self):
+ return self._compose()
+
+ def __repr__(self) -> str:
+ return f"<{self.__class__.__name__}>:\n{self._compose()}"
+
+ def _compose(self):
+ """将所有表达式合并为一个表达式"""
+ if not all(
+ [
+ "bucket" in self.expressions,
+ "measurement" in self.expressions,
+ "range" in self.expressions,
+ ]
+ ):
+ raise AssertionError("bucket, measurement and range must be set")
+
+ expr = [self.expressions[k] for k in ("bucket", "range", "measurement")]
+
+ if self.expressions.get("tags"):
+ expr.append(self.expressions["tags"])
+
+ if self.expressions.get("fields"):
+ expr.append(self.expressions["fields"])
+
+ if "drop" not in self.expressions and self.no_sys_cols:
+ self.drop_sys_cols()
+
+ if self.expressions.get("drop"):
+ expr.append(self.expressions["drop"])
+
+ if self._auto_pivot and "pivot" not in self.expressions:
+ self.pivot()
+
+ if self.expressions.get("pivot"):
+ expr.append(self.expressions["pivot"])
+
+ if self.expressions.get("group"):
+ expr.append(self.expressions["group"])
+
+ if self.expressions.get("sort"):
+ expr.append(self.expressions["sort"])
+
+ if self.expressions.get("limit"):
+ expr.append(self.expressions["limit"])
+
+        # influxdb sorts ascending by default, but a last_n query necessarily returns descending order, so we must re-sort
+ if self._last_n:
+ expr.append(
+ "\n".join(
+ [
+ f' |> top(n: {self._last_n}, columns: ["_time"])',
+ ' |> sort(columns: ["_time"], desc: false)',
+ ]
+ )
+ )
+
+ return "\n".join(expr)
+
+ def bucket(self, bucket: str) -> "Flux":
+ """add bucket to query expression
+
+ Raises:
+ DuplicateOperationError: 一个查询中只允许指定一个source,如果表达式中已经指定了bucket,则抛出异常
+
+ Returns:
+ Flux对象
+ """
+ if "bucket" in self.expressions:
+ raise DuplicateOperationError("bucket has been set")
+
+ self.expressions["bucket"] = f'from(bucket: "{bucket}")'
+
+ return self
+
+ def measurement(self, measurement: str) -> "Flux":
+ """add measurement filter to query
+
+ Raises:
+ DuplicateOperationError: 一次查询中只允许指定一个measurement, 如果表达式中已经存在measurement, 则抛出异常
+
+ Returns:
+ Flux对象自身,以便进行管道操作
+ """
+ if "measurement" in self.expressions:
+ raise DuplicateOperationError("measurement has been set")
+
+ self.expressions[
+ "measurement"
+ ] = f' |> filter(fn: (r) => r["_measurement"] == "{measurement}")'
+
+ return self
+
+ def range(
+ self, start: Frame, end: Frame, right_close=True, precision="s"
+ ) -> "Flux":
+ """添加时间范围过滤
+
+ 必须指定的查询条件,否则influxdb会报unbound查询错,因为这种情况下,返回的数据量将非常大。
+
+ 在格式化时间时,需要根据`precision`生成时间字符串。在向Influxdb发送请求时,应该注意查询参数中指定的时间精度与这里使用的保持一致。
+
+ Influxdb的查询结果默认不包含结束时间,当`right_close`指定为True时,我们将根据指定的精度修改`end`时间,使之仅比`end`多一个时间单位,从而保证查询结果会包含`end`。
+
+ Raises:
+ DuplicateOperationError: 一个查询中只允许指定一次时间范围,如果range表达式已经存在,则抛出异常
+ Args:
+ start: 开始时间
+ end: 结束时间
+ right_close: 查询结果是否包含结束时间。
+ precision: 时间精度,默认为秒。
+
+ Returns:
+ Flux对象,以支持管道操作
+ """
+ if "range" in self.expressions:
+ raise DuplicateOperationError("range has been set")
+
+ if precision not in ["s", "ms", "us"]:
+ raise AssertionError("precision must be 's', 'ms' or 'us'")
+
+ end = self.format_time(end, precision, right_close)
+ start = self.format_time(start, precision)
+
+ self.expressions["range"] = f" |> range(start: {start}, stop: {end})"
+ return self
+
+ def limit(self, limit: int) -> "Flux":
+ """添加返回记录数限制
+
+ Raises:
+ DuplicateOperationError: 一个查询中只允许指定一次limit,如果limit表达式已经存在,则抛出异常
+
+ Args:
+ limit: 返回记录数限制
+
+ Returns:
+ Flux对象,以便进行管道操作
+ """
+ if "limit" in self.expressions:
+ raise DuplicateOperationError("limit has been set")
+
+ self.expressions["limit"] = " |> limit(n: %d)" % limit
+ return self
+
+ @classmethod
+ def to_timestamp(cls, tm: Frame, precision: str = "s") -> int:
+ """将时间根据精度转换为unix时间戳
+
+ 在往influxdb写入数据时,line-protocol要求的时间戳为unix timestamp,并且与其精度对应。
+
+ influxdb始终使用UTC时间,因此,`tm`也必须已经转换成UTC时间。
+
+ Args:
+ tm: 时间
+ precision: 时间精度,默认为秒。
+
+ Returns:
+ 时间戳
+ """
+ if precision not in ["s", "ms", "us"]:
+ raise AssertionError("precision must be 's', 'ms' or 'us'")
+
+ # get int repr of tm, in seconds unit
+ if isinstance(tm, np.datetime64):
+ tm = tm.astype("datetime64[s]").astype("int")
+ elif isinstance(tm, datetime.datetime):
+ tm = tm.timestamp()
+ else:
+ tm = arrow.get(tm).timestamp()
+
+ return int(tm * 10 ** ({"s": 0, "ms": 3, "us": 6}[precision]))
+
+ @classmethod
+ def format_time(cls, tm: Frame, precision: str = "s", shift_forward=False) -> str:
+ """将时间转换成客户端对应的精度,并以 RFC3339 timestamps格式串(即influxdb要求的格式)返回。
+
+ 如果这个时间是作为查询的range中的结束时间使用时,由于influx查询的时间范围是左闭右开的,因此如果你需要查询的是一个闭区间,则需要将`end`的时间向前偏移一个精度。通过传入`shift_forward = True`可以完成这种转换。
+
+ Examples:
+ >>> # by default, the precision is seconds, and convert a date
+ >>> Flux.format_time(datetime.date(2019, 1, 1))
+ '2019-01-01T00:00:00Z'
+
+ >>> # set precision to ms, convert a time
+ >>> Flux.format_time(datetime.datetime(1978, 7, 8, 12, 34, 56, 123456), precision="ms")
+ '1978-07-08T12:34:56.123Z'
+
+ >>> # convert and forward shift
+ >>> Flux.format_time(datetime.date(1978, 7, 8), shift_forward = True)
+ '1978-07-08T00:00:01Z'
+
+        Args:
+            tm : the time to format
+            precision: time precision; one of 's', 'ms', 'us'
+            shift_forward: if True, shift the end time forward by one unit of precision
+
+        Returns:
+            The adjusted time, as a string conforming to influx's time format.
+ """
+ timespec = {"s": "seconds", "ms": "milliseconds", "us": "microseconds"}.get(
+ precision
+ )
+
+ if timespec is None:
+ raise ValueError(
+ f"precision must be one of 's', 'ms', 'us', but got {precision}"
+ )
+
+ tm = arrow.get(tm).naive
+
+ if shift_forward:
+ tm = tm + datetime.timedelta(**{timespec: 1})
+
+ return tm.isoformat(sep="T", timespec=timespec) + "Z"
+
+ def tags(self, tags: DefaultDict[str, List[str]]) -> "Flux":
+ """给查询添加tags过滤条件
+
+ 此查询条件为过滤条件,并非必须。如果查询中没有指定tags,则会返回所有记录。
+
+ 在实现上,既可以使用`contains`语法,也可以使用`or`语法(由于一条记录只能属于一个tag,所以,当指定多个tag进行查询时,它们之间的关系应该为`or`)。经验证,contains语法会始终先将所有符合条件的记录检索出来,再进行过滤。这样的效率比较低,特别是当tags的数量较少时,会远远比使用or语法慢。
+
+ Raises:
+ DuplicateOperationError: 一个查询中只允许执行一次,如果tag filter表达式已经存在,则抛出异常
+
+ Args:
+ tags : tags是一个{tagname: Union[str,[tag_values]]}对象。
+
+ Examples:
+ >>> flux = Flux()
+ >>> flux.tags({"code": ["000001", "000002"], "name": ["浦发银行"]}).expressions["tags"]
+ ' |> filter(fn: (r) => r["code"] == "000001" or r["code"] == "000002" or r["name"] == "浦发银行")'
+
+
+        Returns:
+            The Flux object, to support pipelining.
+        """
+ if "tags" in self.expressions:
+ raise DuplicateOperationError("tags has been set")
+
+ filters = []
+ for tag, values in tags.items():
+ assert (
+ isinstance(values, str) or len(values) > 0
+ ), f"tag {tag} should not be empty or None"
+ if isinstance(values, str):
+ values = [values]
+
+ for v in values:
+ filters.append(f'r["{tag}"] == "{v}"')
+
+ op_expression = " or ".join(filters)
+
+ self.expressions["tags"] = f" |> filter(fn: (r) => {op_expression})"
+
+ return self
+
+ def fields(self, fields: List, reserve_time_stamp: bool = True) -> "Flux":
+ """给查询添加field过滤条件
+
+ 此查询条件为过滤条件,用以指定哪些field会出现在查询结果中,并非必须。如果查询中没有指定tags,则会返回所有记录。
+
+ 由于一条记录只能属于一个_field,所以,当指定多个_field进行查询时,它们之间的关系应该为`or`。
+
+ Raises:
+ DuplicateOperationError: 一个查询中只允许执行一次,如果filed filter表达式已经存在,则抛出异常
+ Args:
+ fields: 待查询的field列表
+ reserve_time_stamp: 是否保留时间戳`_time`,默认为True
+
+ Returns:
+ Flux对象,以便进行管道操作
+ """
+ if "fields" in self.expressions:
+ raise DuplicateOperationError("fields has been set")
+
+ self._cols = fields.copy()
+
+ if reserve_time_stamp and "_time" not in self._cols:
+ self._cols.append("_time")
+
+ self._cols = sorted(self._cols)
+
+ filters = [f'r["_field"] == "{name}"' for name in self._cols]
+
+ self.expressions["fields"] = f" |> filter(fn: (r) => {' or '.join(filters)})"
+
+ return self
+
+ def pivot(
+ self,
+ row_keys: List[str] = ["_time"],
+ column_keys=["_field"],
+ value_column: str = "_value",
+ ) -> "Flux":
+ """pivot用来将以列为单位的数据转换为以行为单位的数据
+
+ Flux查询返回的结果通常都是以列为单位的数据,增加本pivot条件后,结果将被转换成为以行为单位的数据再返回。
+
+ 这里实现的是measurement内的转换,请参考 [pivot](https://docs.influxdata.com/flux/v0.x/stdlib/universe/pivot/#align-fields-within-each-measurement-that-have-the-same-timestamp)
+
+
+ Args:
+ row_keys: 惟一确定输出中一行数据的列名字, 默认为["_time"]
+ column_keys: 列名称列表,默认为["_field"]
+ value_column: 值列名,默认为"_value"
+
+ Returns:
+ Flux对象,以便进行管道操作
+ """
+ if "pivot" in self.expressions:
+ raise DuplicateOperationError("pivot has been set")
+
+ columns = ",".join([f'"{name}"' for name in column_keys])
+ rowkeys = ",".join([f'"{name}"' for name in row_keys])
+
+ self.expressions[
+ "pivot"
+ ] = f' |> pivot(columnKey: [{columns}], rowKey: [{rowkeys}], valueColumn: "{value_column}")'
+
+ return self
+
+ def sort(self, by: List[str] = None, desc: bool = False) -> "Flux":
+ """按照指定的列进行排序
+
+ 根据[influxdb doc](https://docs.influxdata.com/influxdb/v2.0/query-data/flux/first-last/), 查询返回值默认地按时间排序。因此,如果仅仅是要求查询结果按时间排序,无须调用此API,但是,此API提供了按其它字段排序的能力。
+
+ 另外,在一个有5000多个tag,共返回1M条记录的测试中,测试验证返回记录确实按_time升序排列。
+
+ Args:
+ by: 指定排序的列名称列表
+
+ Returns:
+ Flux对象,以便进行管道操作
+ """
+ if "sort" in self.expressions:
+ raise DuplicateOperationError("sort has been set")
+
+ if by is None:
+ by = ["_value"]
+ if isinstance(by, str):
+ by = [by]
+
+ columns_ = ",".join([f'"{name}"' for name in by])
+
+ desc = "true" if desc else "false"
+ self.expressions["sort"] = f" |> sort(columns: [{columns_}], desc: {desc})"
+
+ return self
+
+ def group(self, by: Tuple[str]) -> "Flux":
+ """[summary]
+
+ Returns:
+ [description]
+ """
+ if "group" in self.expressions:
+ raise DuplicateOperationError("group has been set")
+
+ if isinstance(by, str):
+ by = [by]
+ cols = ",".join([f'"{col}"' for col in by])
+ self.expressions["group"] = f" |> group(columns: [{cols}])"
+
+ return self
+
+ def latest(self, n: int) -> "Flux":
+ """获取最后n条数据,按时间增序返回
+
+ Flux查询的增强功能,相当于top + sort + limit
+
+ Args:
+ n: 最后n条数据
+
+ Returns:
+ Flux对象,以便进行管道操作
+ """
+ assert "top" not in self.expressions, "top and last_n can not be used together"
+ assert (
+ "sort" not in self.expressions
+ ), "sort and last_n can not be used together"
+ assert (
+ "limit" not in self.expressions
+ ), "limit and last_n can not be used together"
+
+ self._last_n = n
+
+ return self
+
+ @property
+ def cols(self) -> List[str]:
+ """the columns or the return records
+
+ the implementation is buggy. Influx doesn't tell us in which order these columns are.
+
+
+ Returns:
+            the column names of the returned records
+ """
+ # fixme: if keep in expression, then return group key + tag key + value key
+ # if keep not in expression, then stream, table, _time, ...
+ return sorted(self._cols)
+
+ def delete(
+ self,
+ measurement: str,
+ stop: datetime.datetime,
+ tags: dict = {},
+ start: datetime.datetime = None,
+ precision: str = "s",
+ ) -> dict:
+ """构建删除语句。
+
+ according to [delete-predicate](https://docs.influxdata.com/influxdb/v2.1/reference/syntax/delete-predicate/), delete只支持AND逻辑操作,只支持“=”操作,不支持“!=”操作,可以使用任何字段或者tag,但不包括_time和_value字段。
+
+ 由于influxdb这一段文档不是很清楚,根据试验结果,目前仅支持按时间范围和tags进行删除较好。如果某个column的值类型是字符串,则也可以通过`tags`参数传入,匹配后删除。但如果传入了非字符串类型的column,则将得到无法预料的结果。
+
+ Args:
+ measurement : [description]
+ stop : [description]
+ tags : 按tags和匹配的值进行删除。传入的tags中,key为tag名称,value为tag要匹配的取值,可以为str或者List[str]。
+ start : 起始时间。如果省略,则使用EPOCH_START.
+ precision : 时间精度。可以为“s”,“ms”,“us”
+ Returns:
+ 删除语句
+ """
+ timespec = {"s": "seconds", "ms": "milliseconds", "us": "microseconds"}.get(
+ precision
+ )
+
+ if start is None:
+ start = self.EPOCH_START.isoformat(timespec=timespec) + "Z"
+
+ predicate = [f'_measurement="{measurement}"']
+ for key, value in tags.items():
+ if isinstance(value, list):
+ predicate.extend([f'{key} = "{v}"' for v in value])
+ else:
+ predicate.append(f'{key} = "{value}"')
+
+ command = {
+ "start": start,
+ "stop": f"{stop.isoformat(timespec=timespec)}Z",
+ "predicate": " AND ".join(predicate),
+ }
+
+ return command
+
+ def drop(self, cols: List[str]) -> "Flux":
+ """use this to drop columns before return result
+
+ Args:
+ cols : the name of columns to be dropped
+
+ Returns:
+ Flux object, to support pipe operation
+ """
+ if "drop" in self.expressions:
+ raise DuplicateOperationError("drop operation has been set already")
+
+ # add surrounding quotes
+ _cols = [f'"{c}"' for c in cols]
+ self.expressions["drop"] = f" |> drop(columns: [{','.join(_cols)}])"
+
+ return self
+
+ def drop_sys_cols(self, cols: List[str] = None) -> "Flux":
+ """use this to drop ["_start", "_stop", "_measurement"], plus columns specified in `cols`, before return query result
+
+ please be noticed, after drop sys columns, there's still two sys columns left, which is "_time" and "table", and "_time" should usually be kept, "table" is one we're not able to removed. If you don't like _time in return result, you can specify it in `cols` parameter.
+
+ Args:
+ cols : the extra columns to be dropped
+
+ Returns:
+ Flux query object
+ """
+ _cols = ["_start", "_stop", "_measurement"]
+ if cols is not None:
+ _cols.extend(cols)
+
+ return self.drop(_cols)
+
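Putting the builder together: `_compose` requires that bucket, measurement and range are set, then appends the optional tag/field filters and, unless disabled, the implicit `drop_sys_cols` and `pivot` steps. A minimal sketch of a typical pipeline follows; the bucket name is a placeholder.

```python
import datetime

from omicron.dal.influx.flux import Flux

flux = (
    Flux()
    .bucket("my-bucket")  # placeholder bucket name
    .measurement("board_bars_1d")
    .range(datetime.date(2023, 1, 1), datetime.date(2023, 3, 1))
    .tags({"code": "881128.THS"})
    .fields(["open", "close"])
)

# str(flux) composes the final script:
# from(bucket) |> range |> filter(measurement) |> filter(tags)
# |> filter(fields) |> drop(sys cols) |> pivot(...)
print(flux)
```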
cols: List[str]

    property readonly

The columns of the returned records.

The implementation is buggy: Influx doesn't tell us what order these columns are in.

Returns:
    List[str]: the column names of the returned records
__init__(self, auto_pivot=True, no_sys_cols=True)

    special

Initialize a Flux object.

Parameters:
    auto_pivot: whether to automatically pivot the queried columns into rows. Defaults to True.
    no_sys_cols: whether to automatically drop system columns. Defaults to True. See drop_sys_cols.
bucket(self, bucket)

Add a bucket to the query expression.

Raises:
    DuplicateOperationError: only one source may be specified per query; raised if a bucket is already set.

Returns:
    Flux: the Flux object
delete(self, measurement, stop, tags={}, start=None, precision='s')

Build a delete command.

According to [delete-predicate](https://docs.influxdata.com/influxdb/v2.1/reference/syntax/delete-predicate/), delete supports only the AND logical operator and only the "=" comparison (not "!="); any field or tag may be used, except the _time and _value fields.

Since the influxdb docs here are unclear: experiments suggest it is best to delete only by time range and tags. If a column's values are strings, it may also be passed via `tags` and matched for deletion; passing a non-string column gives unpredictable results.

Parameters:
    measurement: the measurement to delete from. Required.
    stop: end of the deletion time range. Required.
    tags: delete by tags and their matched values; keys are tag names, values are str or List[str]. Defaults to {}.
    start: start time; if omitted, EPOCH_START is used. Defaults to None.
    precision: time precision; one of "s", "ms", "us". Defaults to 's'.

Returns:
    dict: the delete command
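A quick sketch of the command dict `delete` produces; the measurement and tag values are placeholders.

```python
import datetime

from omicron.dal.influx.flux import Flux

cmd = Flux().delete(
    "board_bars_1d",
    datetime.datetime(2023, 3, 1),
    tags={"code": "881128.THS"},
)
print(cmd)
# {'start': '1970-01-01T00:00:00Z',
#  'stop': '2023-03-01T00:00:00Z',
#  'predicate': '_measurement="board_bars_1d" AND code = "881128.THS"'}
```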
drop(self, cols)

Drop the given columns before returning the result.

Parameters:
    cols: the names of the columns to drop. Required.

Returns:
    Flux: the Flux object, to support pipelining
drop_sys_cols(self, cols=None)

Drop ["_start", "_stop", "_measurement"], plus any columns specified in `cols`, before returning the query result.

Note that after dropping the sys columns, two sys columns remain: "_time" and "table". "_time" should usually be kept, and "table" is one we're unable to remove. If you don't want _time in the result, add it to the `cols` parameter.

Parameters:
    cols: extra columns to drop. Defaults to None.

Returns:
    Flux: the Flux query object
fields(self, fields, reserve_time_stamp=True)

Add a field filter to the query.

This is an optional filter that determines which fields appear in the result; if no fields are specified, all records are returned.

Since a record belongs to exactly one _field, multiple _field values are related by `or`.

Raises:
    DuplicateOperationError: only one field filter may be set per query; raised if one already exists.

Parameters:
    fields (List): the list of fields to query. Required.
    reserve_time_stamp (bool): whether to keep the `_time` timestamp. Defaults to True.

Returns:
    Flux: the Flux object, to support pipelining
format_time(tm, precision='s', shift_forward=False)

    classmethod

Convert a time to the client's precision and return it as an RFC3339 timestamp string (the format influxdb requires).

When this time is used as the end of a query range, note that influx ranges are half-open (closed on the left, open on the right); to query a closed interval, shift `end` forward by one unit of precision. Passing `shift_forward = True` performs this conversion.

Examples:
    >>> # by default, the precision is seconds, and convert a date
    >>> Flux.format_time(datetime.date(2019, 1, 1))
    '2019-01-01T00:00:00Z'

    >>> # set precision to ms, convert a time
    >>> Flux.format_time(datetime.datetime(1978, 7, 8, 12, 34, 56, 123456), precision="ms")
    '1978-07-08T12:34:56.123Z'

    >>> # convert and forward shift
    >>> Flux.format_time(datetime.date(1978, 7, 8), shift_forward = True)
    '1978-07-08T00:00:01Z'

Parameters:
    tm: the time to format. Required.
    precision (str): time precision; one of 's', 'ms', 'us'. Defaults to 's'.
    shift_forward: if True, shift the end time forward by one unit of precision. Defaults to False.

Returns:
    str: the adjusted time, as a string conforming to influx's time format
group(self, by)

Group records by the given columns.

Returns:
    Flux: the Flux object, to support pipelining
latest(self, n)

Get the last n records, returned in ascending time order.

An enhanced Flux query, equivalent to top + sort + limit.

Parameters:
    n (int): the number of trailing records. Required.

Returns:
    Flux: the Flux object, to support pipelining
limit(self, limit)

Add a limit on the number of returned records.

Raises:
    DuplicateOperationError: only one limit may be specified per query; raised if a limit expression already exists.

Parameters:
    limit (int): maximum number of records to return. Required.

Returns:
    Flux: the Flux object, to support pipelining
measurement(self, measurement)

Add a measurement filter to the query.

Raises:
    DuplicateOperationError: only one measurement may be specified per query; raised if a measurement is already set.

Returns:
    Flux: the Flux object itself, to support pipelining
pivot(self, row_keys=['_time'], column_keys=['_field'], value_column='_value')

Pivot column-oriented data into row-oriented data.

Flux query results are normally column-oriented; adding this pivot step converts them to row-oriented data before they are returned.

This performs the pivot within a measurement; see [pivot](https://docs.influxdata.com/flux/v0.x/stdlib/universe/pivot/#align-fields-within-each-measurement-that-have-the-same-timestamp).

Parameters:
    row_keys (List[str]): columns that uniquely identify a row in the output. Defaults to ['_time'].
    column_keys: list of column names. Defaults to ['_field'].
    value_column (str): the value column. Defaults to '_value'.

Returns:
    Flux: the Flux object, to support pipelining
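A small sketch of the expression this step emits, using the defaults; `auto_pivot` is disabled here so the pivot must be added explicitly.

```python
from omicron.dal.influx.flux import Flux

flux = Flux(auto_pivot=False)
flux.pivot()  # defaults: rowKey ["_time"], columnKey ["_field"], valueColumn "_value"

print(flux.expressions["pivot"])
#  |> pivot(columnKey: ["_field"], rowKey: ["_time"], valueColumn: "_value")
```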
range(self, start, end, right_close=True, precision='s')

Add a time-range filter.

This condition is mandatory; without it influxdb reports an unbound-query error, because the result set would otherwise be enormous.

When formatting times, the time string is generated according to `precision`. When sending the request to influxdb, make sure the precision specified in the query parameters matches the one used here.

Influxdb query results exclude the end time by default. When `right_close` is True, we adjust `end` by the given precision so it is exactly one time unit later, guaranteeing that the result includes `end`.

Raises:
    DuplicateOperationError: only one time range may be specified per query; raised if a range expression already exists.

Parameters:
    start (Union[datetime.date, datetime.datetime]): start time. Required.
    end (Union[datetime.date, datetime.datetime]): end time. Required.
    right_close: whether the result should include the end time. Defaults to True.
    precision: time precision. Defaults to 's'.

Returns:
    Flux: the Flux object, to support pipelining
sort(self, by=None, desc=False)

Sort by the given columns.

According to the [influxdb doc](https://docs.influxdata.com/influxdb/v2.0/query-data/flux/first-last/), query results are sorted by time by default. So if you only need time ordering there is no need to call this API; it exists to sort by other fields.

In a test with over 5000 tags returning 1M records, the results were verified to be in ascending _time order.

Parameters:
    by (List[str]): list of column names to sort by. Defaults to None.

Returns:
    Flux: the Flux object, to support pipelining
tags(self, tags)

Add a tag filter to the query.

This is an optional filter; if no tags are specified, all records are returned.

The filter could be implemented with either the `contains` syntax or the `or` syntax (since a record belongs to exactly one tag value, multiple tag values are related by `or`). Testing shows that `contains` always retrieves all matching records first and filters afterwards, which is slow; when the number of tags is small it is far slower than the `or` syntax.

Raises:
    DuplicateOperationError: only one tag filter may be set per query; raised if one already exists.

Parameters:
    tags: a {tagname: Union[str, [tag_values]]} mapping. Required.

Examples:
    >>> flux = Flux()
    >>> flux.tags({"code": ["000001", "000002"], "name": ["浦发银行"]}).expressions["tags"]
    ' |> filter(fn: (r) => r["code"] == "000001" or r["code"] == "000002" or r["name"] == "浦发银行")'

Returns:
    Flux: the Flux object, to support pipelining
to_timestamp(tm, precision='s')

    classmethod

Convert a time to a unix timestamp at the given precision.

When writing to influxdb, the line protocol requires a unix timestamp matching the declared precision.

Influxdb always uses UTC, so `tm` must already be in UTC.

Parameters:
    tm (Union[datetime.date, datetime.datetime]): the time. Required.
    precision (str): time precision. Defaults to 's'.

Returns:
    int: the timestamp
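A quick check of the precision scaling, using np.datetime64 input so the result is deterministic (naive datetimes go through `.timestamp()` and depend on the local timezone).

```python
import numpy as np

from omicron.dal.influx.flux import Flux

print(Flux.to_timestamp(np.datetime64("2019-01-01T00:00:00")))        # 1546300800
print(Flux.to_timestamp(np.datetime64("2019-01-01T00:00:00"), "ms"))  # 1546300800000
```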
InfluxClient

omicron/dal/influx/influxclient.py
class InfluxClient:
+ def __init__(
+ self,
+ url: str,
+ token: str,
+ bucket: str,
+ org: str = None,
+ enable_compress=False,
+ chunk_size: int = 5000,
+ precision: str = "s",
+ ):
+ """[summary]
+
+ Args:
+ url ([type]): [description]
+ token ([type]): [description]
+ bucket ([type]): [description]
+ org ([type], optional): [description]. Defaults to None.
+ enable_compress ([type], optional): [description]. Defaults to False.
+ chunk_size: number of lines to be saved in one request
+ precision: 支持的时间精度
+ """
+ self._url = url
+ self._bucket = bucket
+ self._enable_compress = enable_compress
+ self._org = org
+        self._org_id = None  # fetched via query when needed, then never updated
+ self._token = token
+
+        # since influxdb 2.0 the supported precisions are ns, us, ms and s; this client supports only s, ms and us
+ self._precision = precision.lower()
+ if self._precision not in ["s", "ms", "us"]: # pragma: no cover
+ raise ValueError("precision must be one of ['s', 'ms', 'us']")
+
+ self._chunk_size = chunk_size
+
+ # write
+ self._write_url = f"{self._url}/api/v2/write?org={self._org}&bucket={self._bucket}&precision={self._precision}"
+
+ self._write_headers = {
+ "Content-Type": "text/plain; charset=utf-8",
+ "Authorization": f"Token {token}",
+ "Accept": "application/json",
+ }
+
+ if self._enable_compress:
+ self._write_headers["Content-Encoding"] = "gzip"
+
+ self._query_url = f"{self._url}/api/v2/query?org={self._org}"
+ self._query_headers = {
+ "Authorization": f"Token {token}",
+ "Content-Type": "application/vnd.flux",
+        # query result format: whatever is specified (or not), in 2.1 it is always csv
+ "Accept": "text/csv",
+ }
+
+ if self._enable_compress:
+ self._query_headers["Accept-Encoding"] = "gzip"
+
+ self._delete_url = (
+ f"{self._url}/api/v2/delete?org={self._org}&bucket={self._bucket}"
+ )
+ self._delete_headers = {
+ "Authorization": f"Token {token}",
+ "Content-Type": "application/json",
+ }
+
+ async def save(
+ self,
+ data: Union[np.ndarray, DataFrame],
+ measurement: str = None,
+ tag_keys: List[str] = [],
+ time_key: str = None,
+ global_tags: Dict = {},
+ chunk_size: int = None,
+ ) -> None:
+ """save `data` into influxdb
+
+        If `data` is a pandas.DataFrame or a numpy structured array, it is converted to line protocol and saved. If `data` is a str, use the `write` method instead.
+
+ Args:
+ data: data to be saved
+ measurement: the name of measurement
+ tag_keys: which columns name will be used as tags
+ chunk_size: number of lines to be saved in one request. if it's -1, then all data will be written in one request. If it's None, then it will be set to `self._chunk_size`
+
+ Raises:
+ InfluxDBWriteError: if write failed
+
+ """
+ # todo: add more errors raise
+ if isinstance(data, DataFrame):
+ assert (
+ measurement is not None
+ ), "measurement must be specified when data is a DataFrame"
+
+ if tag_keys:
+                assert set(tag_keys) <= set(
+ data.columns.tolist()
+ ), "tag_keys must be in data.columns"
+
+ serializer = DataframeSerializer(
+ data,
+ measurement,
+ time_key,
+ tag_keys,
+ global_tags,
+ precision=self._precision,
+ )
+ if chunk_size == -1:
+ chunk_size = len(data)
+
+ for lines in serializer.serialize(chunk_size or self._chunk_size):
+ await self.write(lines)
+ elif isinstance(data, np.ndarray):
+ assert (
+ measurement is not None
+ ), "measurement must be specified when data is a numpy array"
+ assert (
+ time_key is not None
+ ), "time_key must be specified when data is a numpy array"
+ serializer = NumpySerializer(
+ data,
+ measurement,
+ time_key,
+ tag_keys,
+ global_tags,
+ time_precision=self._precision,
+ )
+ if chunk_size == -1:
+ chunk_size = len(data)
+ for lines in serializer.serialize(chunk_size or self._chunk_size):
+ await self.write(lines)
+ else:
+ raise TypeError(
+ f"data must be pandas.DataFrame, numpy array, got {type(data)}"
+ )
+
+ async def write(self, line_protocol: str):
+ """将line-protocol数组写入influxdb
+
+ Args:
+ line_protocol: 待写入的数据,以line-protocol数组形式存在
+
+ """
+ # todo: add raise error declaration
+ if self._enable_compress:
+ line_protocol_ = gzip.compress(line_protocol.encode("utf-8"))
+ else:
+ line_protocol_ = line_protocol
+
+ async with ClientSession() as session:
+ async with session.post(
+ self._write_url, data=line_protocol_, headers=self._write_headers
+ ) as resp:
+ if resp.status != 204:
+ err = await resp.json()
+ logger.warning(
+ "influxdb write error when processing: %s, err code: %s, message: %s",
+ {line_protocol[:100]},
+ err["code"],
+ err["message"],
+ )
+ logger.debug("data caused error:%s", line_protocol)
+ raise InfluxDBWriteError(
+ f"influxdb write failed, err: {err['message']}"
+ )
+
+ async def query(self, flux: Union[Flux, str], deserializer: Callable = None) -> Any:
+ """flux查询
+
+ flux查询结果是一个以annotated csv格式存储的数据,例如:
+ ```
+ ,result,table,_time,code,amount,close,factor,high,low,open,volume
+ ,_result,0,2019-01-01T00:00:00Z,000001.XSHE,100000000,5.15,1.23,5.2,5,5.1,1000000
+ ```
+
+ 上述`result`中,事先通过Flux.keep()限制了返回的字段为_time,code,amount,close,factor,high,low,open,volume。influxdb查询返回结果时,总是按照字段名称升序排列。此外,总是会额外地返回_result, table两个字段。
+
+ 如果传入了deserializer,则会调用deserializer将其解析成为python对象。否则,返回bytes数据。
+
+ Args:
+ flux: flux查询语句
+ deserializer: 反序列化函数
+
+ Returns:
+ 如果未提供反序列化函数,则返回结果为bytes array(如果指定了compress=True,返回结果为gzip解压缩后的bytes array),否则返回反序列化后的python对象
+ """
+ if isinstance(flux, Flux):
+ flux = str(flux)
+
+ async with ClientSession() as session:
+ async with session.post(
+ self._query_url, data=flux, headers=self._query_headers
+ ) as resp:
+ if resp.status != 200:
+ err = await resp.json()
+ logger.warning(
+ f"influxdb query error: {err} when processing {flux[:500]}"
+ )
+ logger.debug("data caused error:%s", flux)
+ raise InfluxDBQueryError(
+ f"influxdb query failed, status code: {err['message']}"
+ )
+ else:
+ # auto-unzip
+ body = await resp.read()
+ if deserializer:
+ try:
+ return deserializer(body)
+ except Exception as e:
+ logger.exception(e)
+ logger.warning(
+ "failed to deserialize data: %s, the query is:%s",
+ body,
+ flux[:500],
+ )
+ raise
+ else:
+ return body
+
+ async def drop_measurement(self, measurement: str):
+ """从influxdb中删除一个measurement
+
+ 调用此方法后,实际上该measurement仍然存在,只是没有数据。
+
+ """
+ # todo: add raise error declaration
+ await self.delete(measurement, arrow.now().naive)
+
+ async def delete(
+ self,
+ measurement: str,
+ stop: datetime.datetime,
+ tags: Optional[Dict[str, str]] = {},
+ start: datetime.datetime = None,
+ precision: str = "s",
+ ):
+ """删除influxdb中指定时间段内的数据
+
+ 关于参数,请参见[Flux.delete][omicron.dal.influx.flux.Flux.delete]。
+
+ Args:
+ measurement: 指定measurement名字
+ stop: 待删除记录的结束时间
+ start: 待删除记录的开始时间,如果未指定,则使用EPOCH_START
+ tags: 按tag进行过滤的条件
+ precision: 用以格式化起始和结束时间。
+
+ Raises:
+ InfluxDeleteError: 如果删除失败,则抛出此异常
+ """
+ # todo: add raise error declaration
+ command = Flux().delete(
+ measurement, stop, tags, start=start, precision=precision
+ )
+
+ async with ClientSession() as session:
+ async with session.post(
+ self._delete_url, data=json.dumps(command), headers=self._delete_headers
+ ) as resp:
+ if resp.status != 204:
+ err = await resp.json()
+ logger.warning(
+ "influxdb delete error: %s when processin command %s",
+ err["message"],
+ command,
+ )
+ raise InfluxDeleteError(
+ f"influxdb delete failed, status code: {err['message']}"
+ )
+
+ async def list_buckets(self) -> List[Dict]:
+ """列出influxdb中对应token能看到的所有的bucket
+
+ Returns:
+ list of buckets, each bucket is a dict with keys:
+ ```
+ id
+ orgID, a 16 bytes hex string
+ type, system or user
+ description
+ name
+ retentionRules
+ createdAt
+ updatedAt
+ links
+ labels
+ ```
+ """
+ url = f"{self._url}/api/v2/buckets"
+ headers = {"Authorization": f"Token {self._token}"}
+ async with ClientSession() as session:
+ async with session.get(url, headers=headers) as resp:
+ if resp.status != 200:
+ err = await resp.json()
+ raise InfluxSchemaError(
+ f"influxdb list bucket failed, status code: {err['message']}"
+ )
+ else:
+ return (await resp.json())["buckets"]
+
+ async def delete_bucket(self, bucket_id: str = None):
+ """删除influxdb中指定bucket
+
+ Args:
+ bucket_id: 指定bucket的id。如果为None,则会删除本client对应的bucket。
+ """
+ if bucket_id is None:
+ buckets = await self.list_buckets()
+ for bucket in buckets:
+ if bucket["type"] == "user" and bucket["name"] == self._bucket:
+ bucket_id = bucket["id"]
+ break
+ else:
+ raise BadParameterError(
+ "bucket_id is None, and we can't find bucket with name: %s"
+ % self._bucket
+ )
+
+ url = f"{self._url}/api/v2/buckets/{bucket_id}"
+ headers = {"Authorization": f"Token {self._token}"}
+ async with ClientSession() as session:
+ async with session.delete(url, headers=headers) as resp:
+ if resp.status != 204:
+ err = await resp.json()
+ logger.warning(
+ "influxdb delete bucket error: %s when processin command %s",
+ err["message"],
+ bucket_id,
+ )
+ raise InfluxSchemaError(
+ f"influxdb delete bucket failed, status code: {err['message']}"
+ )
+
+ async def create_bucket(
+ self, description=None, retention_rules: List[Dict] = None, org_id: str = None
+ ) -> str:
+ """创建influxdb中指定bucket
+
+ Args:
+            description: 指定bucket的描述
+            retention_rules: bucket的保留策略(对应InfluxDB API中的retentionRules字段)
+            org_id: 指定bucket所属的组织id,如果未指定,则使用本client对应的组织id。
+
+ Raises:
+ InfluxSchemaError: 当influxdb返回错误时,比如重复创建bucket等,会抛出此异常
+ Returns:
+ 新创建的bucket的id
+ """
+ if org_id is None:
+ org_id = await self.query_org_id()
+
+ url = f"{self._url}/api/v2/buckets"
+ headers = {"Authorization": f"Token {self._token}"}
+ data = {
+ "name": self._bucket,
+ "orgID": org_id,
+ "description": description,
+ "retentionRules": retention_rules,
+ }
+ async with ClientSession() as session:
+ async with session.post(
+ url, data=json.dumps(data), headers=headers
+ ) as resp:
+ if resp.status != 201:
+ err = await resp.json()
+ logger.warning(
+ "influxdb create bucket error: %s when processin command %s",
+ err["message"],
+ data,
+ )
+ raise InfluxSchemaError(
+ f"influxdb create bucket failed, status code: {err['message']}"
+ )
+ else:
+ result = await resp.json()
+ return result["id"]
+
+ async def list_organizations(self, offset: int = 0, limit: int = 100) -> List[Dict]:
+ """列出本客户端允许查询的所组织
+
+ Args:
+ offset : 分页起点
+ limit : 每页size
+
+ Raises:
+ InfluxSchemaError: influxdb返回的错误
+
+ Returns:
+ list of organizations, each organization is a dict with keys:
+ ```
+ id : the id of the org
+ links
+ name : the name of the org
+ description
+ createdAt
+ updatedAt
+ ```
+ """
+ url = f"{self._url}/api/v2/orgs?offset={offset}&limit={limit}"
+ headers = {"Authorization": f"Token {self._token}"}
+
+ async with ClientSession() as session:
+ async with session.get(url, headers=headers) as resp:
+ if resp.status != 200:
+ err = await resp.json()
+ logger.warning("influxdb query orgs err: %s", err["message"])
+ raise InfluxSchemaError(
+ f"influxdb query orgs failed, status code: {err['message']}"
+ )
+ else:
+ return (await resp.json())["orgs"]
+
+ async def query_org_id(self, name: str = None) -> str:
+ """通过组织名查找组织id
+
+        只能查询本客户端允许查询的组织。如果name未提供,则使用本客户端创建时传入的组织名。
+
+ Args:
+ name: 指定组织名
+
+ Returns:
+ 组织id
+ """
+ if name is None:
+ name = self._org
+ orgs = await self.list_organizations()
+ for org in orgs:
+ if org["name"] == name:
+ return org["id"]
+
+ raise BadParameterError(f"can't find org with name: {name}")
+
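+下面是一个最小的使用示意(假设该类以InfluxClient为名从omicron.dal.influx.influxclient导出;url、token、bucket、org均为占位值):
+
+```python
+import asyncio
+
+from omicron.dal.influx.influxclient import InfluxClient
+
+
+async def main():
+    client = InfluxClient(
+        "http://localhost:8086", "my-token", "my-bucket", org="my-org"
+    )
+    # 列出当前token可见的bucket
+    for bucket in await client.list_buckets():
+        print(bucket["name"], bucket["id"])
+
+
+asyncio.run(main())
+```
+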
__init__(self, url, token, bucket, org=None, enable_compress=False, chunk_size=5000, precision='s') special
+
+InfluxClient构造函数。
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| url | str | InfluxDB服务器的URL | required |
+| token | str | 访问用的token | required |
+| bucket | str | bucket名 | required |
+| org | str | 组织名。Defaults to None. | None |
+| enable_compress | bool | 是否对读写的数据启用gzip压缩。Defaults to False. | False |
+| chunk_size | int | number of lines to be saved in one request | 5000 |
+| precision | str | 支持的时间精度 | 's' |
+
omicron/dal/influx/influxclient.py
def __init__(
+ self,
+ url: str,
+ token: str,
+ bucket: str,
+ org: str = None,
+ enable_compress=False,
+ chunk_size: int = 5000,
+ precision: str = "s",
+):
+ """[summary]
+
+ Args:
+ url ([type]): [description]
+ token ([type]): [description]
+ bucket ([type]): [description]
+ org ([type], optional): [description]. Defaults to None.
+ enable_compress ([type], optional): [description]. Defaults to False.
+ chunk_size: number of lines to be saved in one request
+ precision: 支持的时间精度
+ """
+ self._url = url
+ self._bucket = bucket
+ self._enable_compress = enable_compress
+ self._org = org
+ self._org_id = None # 需要时通过查询获取,此后不再更新
+ self._token = token
+
+ # influxdb 2.0起支持的时间精度有:ns, us, ms, s。本客户端只支持s, ms和us
+ self._precision = precision.lower()
+ if self._precision not in ["s", "ms", "us"]: # pragma: no cover
+ raise ValueError("precision must be one of ['s', 'ms', 'us']")
+
+ self._chunk_size = chunk_size
+
+ # write
+ self._write_url = f"{self._url}/api/v2/write?org={self._org}&bucket={self._bucket}&precision={self._precision}"
+
+ self._write_headers = {
+ "Content-Type": "text/plain; charset=utf-8",
+ "Authorization": f"Token {token}",
+ "Accept": "application/json",
+ }
+
+ if self._enable_compress:
+ self._write_headers["Content-Encoding"] = "gzip"
+
+ self._query_url = f"{self._url}/api/v2/query?org={self._org}"
+ self._query_headers = {
+ "Authorization": f"Token {token}",
+ "Content-Type": "application/vnd.flux",
+ # influx查询结果格式,无论如何指定(或者不指定),在2.1中始终是csv格式
+ "Accept": "text/csv",
+ }
+
+ if self._enable_compress:
+ self._query_headers["Accept-Encoding"] = "gzip"
+
+ self._delete_url = (
+ f"{self._url}/api/v2/delete?org={self._org}&bucket={self._bucket}"
+ )
+ self._delete_headers = {
+ "Authorization": f"Token {token}",
+ "Content-Type": "application/json",
+ }
+
create_bucket(self, description=None, retention_rules=None, org_id=None) async
+
+创建influxdb中指定bucket
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| description | str | 指定bucket的描述 | None |
+| retention_rules | List[Dict] | bucket的保留策略(对应InfluxDB API中的retentionRules字段) | None |
+| org_id | str | 指定bucket所属的组织id,如果未指定,则使用本client对应的组织id。 | None |
+
+Exceptions:
+
+| Type | Description |
+| --- | --- |
+| InfluxSchemaError | 当influxdb返回错误时,比如重复创建bucket等,会抛出此异常 |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| str | 新创建的bucket的id |
+
omicron/dal/influx/influxclient.py
async def create_bucket(
+ self, description=None, retention_rules: List[Dict] = None, org_id: str = None
+) -> str:
+ """创建influxdb中指定bucket
+
+ Args:
+        description: 指定bucket的描述
+        retention_rules: bucket的保留策略(对应InfluxDB API中的retentionRules字段)
+        org_id: 指定bucket所属的组织id,如果未指定,则使用本client对应的组织id。
+
+ Raises:
+ InfluxSchemaError: 当influxdb返回错误时,比如重复创建bucket等,会抛出此异常
+ Returns:
+ 新创建的bucket的id
+ """
+ if org_id is None:
+ org_id = await self.query_org_id()
+
+ url = f"{self._url}/api/v2/buckets"
+ headers = {"Authorization": f"Token {self._token}"}
+ data = {
+ "name": self._bucket,
+ "orgID": org_id,
+ "description": description,
+ "retentionRules": retention_rules,
+ }
+ async with ClientSession() as session:
+ async with session.post(
+ url, data=json.dumps(data), headers=headers
+ ) as resp:
+ if resp.status != 201:
+ err = await resp.json()
+                logger.warning(
+                    "influxdb create bucket error: %s when processing command %s",
+                    err["message"],
+                    data,
+                )
+                raise InfluxSchemaError(
+                    f"influxdb create bucket failed, err: {err['message']}"
+                )
+ else:
+ result = await resp.json()
+ return result["id"]
+
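+一个调用示意(假设client为前文构造的InfluxClient实例;retention_rules的结构对应InfluxDB 2.x API的retentionRules字段,取值仅为示例):
+
+```python
+# 创建与本client的bucket名同名的bucket,数据保留30天
+bucket_id = await client.create_bucket(
+    description="demo bucket",
+    retention_rules=[{"type": "expire", "everySeconds": 30 * 86400}],
+)
+```
+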
delete(self, measurement, stop, tags={}, start=None, precision='s') async
+
+删除influxdb中指定时间段内的数据
+
+关于参数,请参见Flux.delete。
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| measurement | str | 指定measurement名字 | required |
+| stop | datetime | 待删除记录的结束时间 | required |
+| start | datetime | 待删除记录的开始时间,如果未指定,则使用EPOCH_START | None |
+| tags | Optional[Dict[str, str]] | 按tag进行过滤的条件 | {} |
+| precision | str | 用以格式化起始和结束时间。 | 's' |
+
+Exceptions:
+
+| Type | Description |
+| --- | --- |
+| InfluxDeleteError | 如果删除失败,则抛出此异常 |
+
omicron/dal/influx/influxclient.py
async def delete(
+ self,
+ measurement: str,
+ stop: datetime.datetime,
+ tags: Optional[Dict[str, str]] = {},
+ start: datetime.datetime = None,
+ precision: str = "s",
+):
+ """删除influxdb中指定时间段内的数据
+
+ 关于参数,请参见[Flux.delete][omicron.dal.influx.flux.Flux.delete]。
+
+ Args:
+ measurement: 指定measurement名字
+ stop: 待删除记录的结束时间
+ start: 待删除记录的开始时间,如果未指定,则使用EPOCH_START
+ tags: 按tag进行过滤的条件
+ precision: 用以格式化起始和结束时间。
+
+ Raises:
+ InfluxDeleteError: 如果删除失败,则抛出此异常
+ """
+ # todo: add raise error declaration
+ command = Flux().delete(
+ measurement, stop, tags, start=start, precision=precision
+ )
+
+ async with ClientSession() as session:
+ async with session.post(
+ self._delete_url, data=json.dumps(command), headers=self._delete_headers
+ ) as resp:
+ if resp.status != 204:
+ err = await resp.json()
+                logger.warning(
+                    "influxdb delete error: %s when processing command %s",
+                    err["message"],
+                    command,
+                )
+                raise InfluxDeleteError(
+                    f"influxdb delete failed, err: {err['message']}"
+                )
+
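+一个调用示意(假设client为前文构造的InfluxClient实例;measurement名与tag取值均为占位):
+
+```python
+import arrow
+
+# 删除截至当前时刻、code标签为000001.XSHE的全部记录
+await client.delete("demo", arrow.now().naive, tags={"code": "000001.XSHE"})
+```
+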
delete_bucket(self, bucket_id=None) async
+
+删除influxdb中指定bucket
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| bucket_id | str | 指定bucket的id。如果为None,则会删除本client对应的bucket。 | None |
+
omicron/dal/influx/influxclient.py
async def delete_bucket(self, bucket_id: str = None):
+ """删除influxdb中指定bucket
+
+ Args:
+ bucket_id: 指定bucket的id。如果为None,则会删除本client对应的bucket。
+ """
+ if bucket_id is None:
+ buckets = await self.list_buckets()
+ for bucket in buckets:
+ if bucket["type"] == "user" and bucket["name"] == self._bucket:
+ bucket_id = bucket["id"]
+ break
+ else:
+ raise BadParameterError(
+ "bucket_id is None, and we can't find bucket with name: %s"
+ % self._bucket
+ )
+
+ url = f"{self._url}/api/v2/buckets/{bucket_id}"
+ headers = {"Authorization": f"Token {self._token}"}
+ async with ClientSession() as session:
+ async with session.delete(url, headers=headers) as resp:
+ if resp.status != 204:
+ err = await resp.json()
+                logger.warning(
+                    "influxdb delete bucket error: %s when processing command %s",
+                    err["message"],
+                    bucket_id,
+                )
+                raise InfluxSchemaError(
+                    f"influxdb delete bucket failed, err: {err['message']}"
+                )
+
drop_measurement(self, measurement) async
+
+从influxdb中删除一个measurement
+
+调用此方法后,实际上该measurement仍然存在,只是没有数据。
+
+ +omicron/dal/influx/influxclient.py
async def drop_measurement(self, measurement: str):
+ """从influxdb中删除一个measurement
+
+ 调用此方法后,实际上该measurement仍然存在,只是没有数据。
+
+ """
+ # todo: add raise error declaration
+ await self.delete(measurement, arrow.now().naive)
+
list_buckets(self) async
+
+列出influxdb中对应token能看到的所有的bucket
+
+Returns:
+
+list of buckets, each bucket is a dict with keys:
+
+```
+id
+orgID, a 16 bytes hex string
+type, system or user
+description
+name
+retentionRules
+createdAt
+updatedAt
+links
+labels
+```
+
+ +omicron/dal/influx/influxclient.py
async def list_buckets(self) -> List[Dict]:
+ """列出influxdb中对应token能看到的所有的bucket
+
+ Returns:
+ list of buckets, each bucket is a dict with keys:
+ ```
+ id
+ orgID, a 16 bytes hex string
+ type, system or user
+ description
+ name
+ retentionRules
+ createdAt
+ updatedAt
+ links
+ labels
+ ```
+ """
+ url = f"{self._url}/api/v2/buckets"
+ headers = {"Authorization": f"Token {self._token}"}
+ async with ClientSession() as session:
+ async with session.get(url, headers=headers) as resp:
+ if resp.status != 200:
+ err = await resp.json()
+ raise InfluxSchemaError(
+ f"influxdb list bucket failed, status code: {err['message']}"
+ )
+ else:
+ return (await resp.json())["buckets"]
+
list_organizations(self, offset=0, limit=100) async
+
+列出本客户端允许查询的所有组织
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| offset | int | 分页起点 | 0 |
+| limit | int | 每页size | 100 |
+
+Exceptions:
+
+| Type | Description |
+| --- | --- |
+| InfluxSchemaError | influxdb返回的错误 |
+
+Returns:
+
+list of organizations, each organization is a dict with keys:
+
+```
+id : the id of the org
+links
+name : the name of the org
+description
+createdAt
+updatedAt
+```
+
omicron/dal/influx/influxclient.py
async def list_organizations(self, offset: int = 0, limit: int = 100) -> List[Dict]:
+ """列出本客户端允许查询的所组织
+
+ Args:
+ offset : 分页起点
+ limit : 每页size
+
+ Raises:
+ InfluxSchemaError: influxdb返回的错误
+
+ Returns:
+ list of organizations, each organization is a dict with keys:
+ ```
+ id : the id of the org
+ links
+ name : the name of the org
+ description
+ createdAt
+ updatedAt
+ ```
+ """
+ url = f"{self._url}/api/v2/orgs?offset={offset}&limit={limit}"
+ headers = {"Authorization": f"Token {self._token}"}
+
+ async with ClientSession() as session:
+ async with session.get(url, headers=headers) as resp:
+ if resp.status != 200:
+ err = await resp.json()
+ logger.warning("influxdb query orgs err: %s", err["message"])
+ raise InfluxSchemaError(
+ f"influxdb query orgs failed, status code: {err['message']}"
+ )
+ else:
+ return (await resp.json())["orgs"]
+
query(self, flux, deserializer=None) async
+
+flux查询
+
+flux查询结果是一个以annotated csv格式存储的数据,例如:
+
+```
+,result,table,_time,code,amount,close,factor,high,low,open,volume
+,_result,0,2019-01-01T00:00:00Z,000001.XSHE,100000000,5.15,1.23,5.2,5,5.1,1000000
+```
+
+上述`result`中,事先通过Flux.keep()限制了返回的字段为_time,code,amount,close,factor,high,low,open,volume。influxdb查询返回结果时,总是按照字段名称升序排列。此外,总是会额外地返回_result, table两个字段。
+
+如果传入了deserializer,则会调用deserializer将其解析成为python对象。否则,返回bytes数据。
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| flux | Union[omicron.dal.influx.flux.Flux, str] | flux查询语句 | required |
+| deserializer | Callable | 反序列化函数 | None |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| Any | 如果未提供反序列化函数,则返回结果为bytes array(如果指定了enable_compress=True,返回结果为gzip解压缩后的bytes array),否则返回反序列化后的python对象 |
+
omicron/dal/influx/influxclient.py
async def query(self, flux: Union[Flux, str], deserializer: Callable = None) -> Any:
+ """flux查询
+
+ flux查询结果是一个以annotated csv格式存储的数据,例如:
+ ```
+ ,result,table,_time,code,amount,close,factor,high,low,open,volume
+ ,_result,0,2019-01-01T00:00:00Z,000001.XSHE,100000000,5.15,1.23,5.2,5,5.1,1000000
+ ```
+
+ 上述`result`中,事先通过Flux.keep()限制了返回的字段为_time,code,amount,close,factor,high,low,open,volume。influxdb查询返回结果时,总是按照字段名称升序排列。此外,总是会额外地返回_result, table两个字段。
+
+ 如果传入了deserializer,则会调用deserializer将其解析成为python对象。否则,返回bytes数据。
+
+ Args:
+ flux: flux查询语句
+ deserializer: 反序列化函数
+
+ Returns:
+        如果未提供反序列化函数,则返回结果为bytes array(如果指定了enable_compress=True,返回结果为gzip解压缩后的bytes array),否则返回反序列化后的python对象
+ """
+ if isinstance(flux, Flux):
+ flux = str(flux)
+
+ async with ClientSession() as session:
+ async with session.post(
+ self._query_url, data=flux, headers=self._query_headers
+ ) as resp:
+ if resp.status != 200:
+ err = await resp.json()
+ logger.warning(
+ f"influxdb query error: {err} when processing {flux[:500]}"
+ )
+ logger.debug("data caused error:%s", flux)
+ raise InfluxDBQueryError(
+ f"influxdb query failed, status code: {err['message']}"
+ )
+ else:
+ # auto-unzip
+ body = await resp.read()
+ if deserializer:
+ try:
+ return deserializer(body)
+ except Exception as e:
+ logger.exception(e)
+ logger.warning(
+ "failed to deserialize data: %s, the query is:%s",
+ body,
+ flux[:500],
+ )
+ raise
+ else:
+ return body
+
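+一个调用示意(假设client为前文构造的InfluxClient实例,bucket与measurement名为占位值;DataframeDeserializer见下文serialize模块):
+
+```python
+from omicron.dal.influx.serialize import DataframeDeserializer
+
+flux = """
+from(bucket: "my-bucket")
+  |> range(start: -7d)
+  |> filter(fn: (r) => r._measurement == "demo")
+"""
+
+# 不传deserializer时返回annotated csv的bytes;传入后返回DataFrame
+df = await client.query(flux, deserializer=DataframeDeserializer())
+```
+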
query_org_id(self, name=None) async
+
+通过组织名查找组织id
+
+只能查询本客户端允许查询的组织。如果name未提供,则使用本客户端创建时传入的组织名。
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| name | str | 指定组织名 | None |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| str | 组织id |
+
omicron/dal/influx/influxclient.py
async def query_org_id(self, name: str = None) -> str:
+ """通过组织名查找组织id
+
+    只能查询本客户端允许查询的组织。如果name未提供,则使用本客户端创建时传入的组织名。
+
+ Args:
+ name: 指定组织名
+
+ Returns:
+ 组织id
+ """
+ if name is None:
+ name = self._org
+ orgs = await self.list_organizations()
+ for org in orgs:
+ if org["name"] == name:
+ return org["id"]
+
+ raise BadParameterError(f"can't find org with name: {name}")
+
save(self, data, measurement=None, tag_keys=[], time_key=None, global_tags={}, chunk_size=None) async
+
+save `data` into influxdb
+
+if `data` is a pandas.DataFrame or numpy structured array, it will be converted to line protocol and saved. If `data` is str, use `write` method instead.
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| data | Union[numpy.ndarray, pandas.core.frame.DataFrame] | data to be saved | required |
+| measurement | str | the name of measurement | None |
+| tag_keys | List[str] | which columns name will be used as tags | [] |
+| time_key | str | which column will be used as the timestamp | None |
+| global_tags | Dict | tags to be added to every line | {} |
+| chunk_size | int | number of lines to be saved in one request. if it's -1, then all data will be written in one request. If it's None, then it will be set to `self._chunk_size` | None |
+
+Exceptions:
+
+| Type | Description |
+| --- | --- |
+| InfluxDBWriteError | if write failed |
+
omicron/dal/influx/influxclient.py
async def save(
+ self,
+ data: Union[np.ndarray, DataFrame],
+ measurement: str = None,
+ tag_keys: List[str] = [],
+ time_key: str = None,
+ global_tags: Dict = {},
+ chunk_size: int = None,
+) -> None:
+ """save `data` into influxdb
+
+    if `data` is a pandas.DataFrame or numpy structured array, it will be converted to line protocol and saved. If `data` is str, use `write` method instead.
+
+ Args:
+ data: data to be saved
+ measurement: the name of measurement
+        tag_keys: which columns name will be used as tags
+        time_key: which column will be used as the timestamp
+        global_tags: tags to be added to every line
+        chunk_size: number of lines to be saved in one request. if it's -1, then all data will be written in one request. If it's None, then it will be set to `self._chunk_size`
+
+ Raises:
+ InfluxDBWriteError: if write failed
+
+ """
+ # todo: add more errors raise
+ if isinstance(data, DataFrame):
+ assert (
+ measurement is not None
+ ), "measurement must be specified when data is a DataFrame"
+
+ if tag_keys:
+            assert set(tag_keys).issubset(
+                set(data.columns.tolist())
+            ), "tag_keys must be in data.columns"
+
+ serializer = DataframeSerializer(
+ data,
+ measurement,
+ time_key,
+ tag_keys,
+ global_tags,
+ precision=self._precision,
+ )
+ if chunk_size == -1:
+ chunk_size = len(data)
+
+ for lines in serializer.serialize(chunk_size or self._chunk_size):
+ await self.write(lines)
+ elif isinstance(data, np.ndarray):
+ assert (
+ measurement is not None
+ ), "measurement must be specified when data is a numpy array"
+ assert (
+ time_key is not None
+ ), "time_key must be specified when data is a numpy array"
+ serializer = NumpySerializer(
+ data,
+ measurement,
+ time_key,
+ tag_keys,
+ global_tags,
+ time_precision=self._precision,
+ )
+ if chunk_size == -1:
+ chunk_size = len(data)
+ for lines in serializer.serialize(chunk_size or self._chunk_size):
+ await self.write(lines)
+ else:
+ raise TypeError(
+ f"data must be pandas.DataFrame, numpy array, got {type(data)}"
+ )
+
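+一个调用示意(假设client为前文构造的InfluxClient实例;measurement、字段名与dtype均为假设):
+
+```python
+import numpy as np
+
+bars = np.array(
+    [
+        ("2023-01-03T00:00:00", "000001.XSHE", 10.2),
+        ("2023-01-04T00:00:00", "000001.XSHE", 10.5),
+    ],
+    dtype=[("frame", "datetime64[s]"), ("code", "<U11"), ("close", "<f4")],
+)
+
+# numpy structured array必须指定measurement和time_key
+await client.save(bars, "demo", time_key="frame", tag_keys=["code"])
+```
+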
write(self, line_protocol) async
+
+将line-protocol数组写入influxdb
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| line_protocol | str | 待写入的数据,以line-protocol数组形式存在 | required |
+
omicron/dal/influx/influxclient.py
async def write(self, line_protocol: str):
+ """将line-protocol数组写入influxdb
+
+ Args:
+ line_protocol: 待写入的数据,以line-protocol数组形式存在
+
+ """
+ # todo: add raise error declaration
+ if self._enable_compress:
+ line_protocol_ = gzip.compress(line_protocol.encode("utf-8"))
+ else:
+ line_protocol_ = line_protocol
+
+ async with ClientSession() as session:
+ async with session.post(
+ self._write_url, data=line_protocol_, headers=self._write_headers
+ ) as resp:
+ if resp.status != 204:
+ err = await resp.json()
+ logger.warning(
+ "influxdb write error when processing: %s, err code: %s, message: %s",
+                    line_protocol[:100],
+ err["code"],
+ err["message"],
+ )
+ logger.debug("data caused error:%s", line_protocol)
+ raise InfluxDBWriteError(
+ f"influxdb write failed, err: {err['message']}"
+ )
+
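+一个调用示意(假设client为前文构造的InfluxClient实例;precision为默认的's'时,时间戳以秒为单位):
+
+```python
+lines = "\n".join(
+    [
+        "demo,code=000001.XSHE close=10.2 1672704000",
+        "demo,code=000001.XSHE close=10.5 1672790400",
+    ]
+)
+await client.write(lines)
+```
+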
omicron/dal/influx/serialize.py
class DataframeDeserializer(Serializer):
+ def __init__(
+ self,
+ sort_values: Union[str, List[str]] = None,
+ encoding: str = "utf-8",
+ names: List[str] = None,
+ usecols: Union[List[int], List[str]] = None,
+ dtype: dict = None,
+ time_col: Union[int, str] = None,
+ sep: str = ",",
+ header: Union[int, List[int], str] = "infer",
+ engine: str = None,
+ infer_datetime_format=True,
+ lineterminator: str = None,
+ converters: dict = None,
+ skipfooter=0,
+ index_col: Union[int, str, List[int], List[str], bool] = None,
+ skiprows: Union[int, List[int], Callable] = None,
+ **kwargs,
+ ):
+ """constructor a deserializer which convert a csv-like bytes array to pandas.DataFrame
+
+ the args are the same as pandas.read_csv. for details, please refer to the official doc: [pandas.read_csv](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html)
+
+ for performance consideration, please specify the following args:
+        - engine = 'c' or 'pyarrow' when possible. Note that 'pyarrow' is the fastest (multi-threading supported) but may be error-prone. Only use it after thorough testing.
+
+ - specify dtype when possible
+
+ use `usecols` to specify the columns to read, and `names` to specify the column names (i.e., rename the columns), otherwise, the column names will be inferred from the first line.
+
+        when `names` is specified, it has to be of the same length as the actual columns of the data. If this causes column renaming, then you should always use the column names specified in `names` to access the data (instead of those in `usecols`).
+
+ Examples:
+ >>> data = ",result,table,_time,code,name\\r\\n,_result,0,2019-01-01T09:31:00Z,000002.XSHE,国联证券"
+ >>> des = DataframeDeserializer(names=["_", "result", "table", "frame", "code", "name"], usecols=["frame", "code", "name"])
+ >>> des(data)
+ frame code name
+ 0 2019-01-01T09:31:00Z 000002.XSHE 国联证券
+
+ Args:
+ sort_values: sort the dataframe by the specified columns
+            encoding: if the data is bytes, then encoding is required, since pandas.read_csv only handles string input
+ sep: the separator/delimiter of each fields
+ header: the row number of the header, default is 'infer'
+ names: the column names of the dataframe
+ index_col: the column number or name of the index column
+ usecols: the column name of the columns to use
+ dtype: the dtype of the columns
+ engine: the engine of the csv file, default is None
+ converters: specify converter for columns.
+ skiprows: the row number to skip
+ skipfooter: the row number to skip at the end of the file
+ time_col: the columns to parse as dates
+ infer_datetime_format: whether to infer the datetime format
+ lineterminator: the line terminator of the csv file, only valid when engine is 'c'
+ kwargs: other arguments
+ """
+ self.sort_values = sort_values
+ self.encoding = encoding
+ self.sep = sep
+ self.header = header
+ self.names = names
+ self.index_col = index_col
+ self.usecols = usecols
+ self.dtype = dtype
+ self.engine = engine
+ self.converters = converters or {}
+ self.skiprows = skiprows
+ self.skipfooter = skipfooter
+ self.infer_datetime_format = infer_datetime_format
+
+ self.lineterminator = lineterminator
+ self.kwargs = kwargs
+
+ if names is not None:
+ self.header = 0
+
+ if time_col is not None:
+ self.converters[time_col] = lambda x: ciso8601.parse_datetime_as_naive(x)
+
+ def __call__(self, data: Union[str, bytes]) -> pd.DataFrame:
+ if isinstance(data, str):
+ # treat data as string
+ stream = io.StringIO(data)
+ else:
+ stream = io.StringIO(data.decode(self.encoding))
+
+ df = pd.read_csv(
+ stream,
+ sep=self.sep,
+ header=self.header,
+ names=self.names,
+ index_col=self.index_col,
+ usecols=self.usecols,
+ dtype=self.dtype,
+ engine=self.engine,
+ converters=self.converters,
+ skiprows=self.skiprows,
+ skipfooter=self.skipfooter,
+ infer_datetime_format=self.infer_datetime_format,
+ lineterminator=self.lineterminator,
+ **self.kwargs,
+ )
+
+ if self.usecols:
+ df = df[list(self.usecols)]
+ if self.sort_values is not None:
+ return df.sort_values(self.sort_values)
+ else:
+ return df
+
__init__(self, sort_values=None, encoding='utf-8', names=None, usecols=None, dtype=None, time_col=None, sep=',', header='infer', engine=None, infer_datetime_format=True, lineterminator=None, converters=None, skipfooter=0, index_col=None, skiprows=None, **kwargs) special
+
+construct a deserializer which converts a csv-like bytes array to pandas.DataFrame
+
+the args are the same as pandas.read_csv. for details, please refer to the official doc: [pandas.read_csv](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html)
+
+for performance consideration, please specify the following args:
+
+- engine = 'c' or 'pyarrow' when possible. Note that 'pyarrow' is the fastest (multi-threading supported) but may be error-prone. Only use it after thorough testing.
+- specify dtype when possible
+
+use `usecols` to specify the columns to read, and `names` to specify the column names (i.e., rename the columns), otherwise, the column names will be inferred from the first line.
+
+when `names` is specified, it has to be of the same length as the actual columns of the data. If this causes column renaming, then you should always use the column names specified in `names` to access the data (instead of those in `usecols`).
Examples:
+>>> data = ",result,table,_time,code,name\r\n,_result,0,2019-01-01T09:31:00Z,000002.XSHE,国联证券"
+>>> des = DataframeDeserializer(names=["_", "result", "table", "frame", "code", "name"], usecols=["frame", "code", "name"])
+>>> des(data)
+ frame code name
+0 2019-01-01T09:31:00Z 000002.XSHE 国联证券
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| sort_values | Union[str, List[str]] | sort the dataframe by the specified columns | None |
+| encoding | str | if the data is bytes, then encoding is required, since pandas.read_csv only handles string input | 'utf-8' |
+| sep | str | the separator/delimiter of each fields | ',' |
+| header | Union[int, List[int], str] | the row number of the header, default is 'infer' | 'infer' |
+| names | List[str] | the column names of the dataframe | None |
+| index_col | Union[int, str, List[int], List[str], bool] | the column number or name of the index column | None |
+| usecols | Union[List[int], List[str]] | the column name of the columns to use | None |
+| dtype | dict | the dtype of the columns | None |
+| engine | str | the engine of the csv file, default is None | None |
+| converters | dict | specify converter for columns. | None |
+| skiprows | Union[int, List[int], Callable] | the row number to skip | None |
+| skipfooter | int | the row number to skip at the end of the file | 0 |
+| time_col | Union[int, str] | the columns to parse as dates | None |
+| infer_datetime_format | bool | whether to infer the datetime format | True |
+| lineterminator | str | the line terminator of the csv file, only valid when engine is 'c' | None |
+| kwargs | | other arguments | {} |
+
omicron/dal/influx/serialize.py
def __init__(
+ self,
+ sort_values: Union[str, List[str]] = None,
+ encoding: str = "utf-8",
+ names: List[str] = None,
+ usecols: Union[List[int], List[str]] = None,
+ dtype: dict = None,
+ time_col: Union[int, str] = None,
+ sep: str = ",",
+ header: Union[int, List[int], str] = "infer",
+ engine: str = None,
+ infer_datetime_format=True,
+ lineterminator: str = None,
+ converters: dict = None,
+ skipfooter=0,
+ index_col: Union[int, str, List[int], List[str], bool] = None,
+ skiprows: Union[int, List[int], Callable] = None,
+ **kwargs,
+):
+ """constructor a deserializer which convert a csv-like bytes array to pandas.DataFrame
+
+ the args are the same as pandas.read_csv. for details, please refer to the official doc: [pandas.read_csv](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html)
+
+ for performance consideration, please specify the following args:
+        - engine = 'c' or 'pyarrow' when possible. Note that 'pyarrow' is the fastest (multi-threading supported) but may be error-prone. Only use it after thorough testing.
+
+ - specify dtype when possible
+
+ use `usecols` to specify the columns to read, and `names` to specify the column names (i.e., rename the columns), otherwise, the column names will be inferred from the first line.
+
+    when `names` is specified, it has to be of the same length as the actual columns of the data. If this causes column renaming, then you should always use the column names specified in `names` to access the data (instead of those in `usecols`).
+
+ Examples:
+ >>> data = ",result,table,_time,code,name\\r\\n,_result,0,2019-01-01T09:31:00Z,000002.XSHE,国联证券"
+ >>> des = DataframeDeserializer(names=["_", "result", "table", "frame", "code", "name"], usecols=["frame", "code", "name"])
+ >>> des(data)
+ frame code name
+ 0 2019-01-01T09:31:00Z 000002.XSHE 国联证券
+
+ Args:
+ sort_values: sort the dataframe by the specified columns
+        encoding: if the data is bytes, then encoding is required, since pandas.read_csv only handles string input
+ sep: the separator/delimiter of each fields
+ header: the row number of the header, default is 'infer'
+ names: the column names of the dataframe
+ index_col: the column number or name of the index column
+ usecols: the column name of the columns to use
+ dtype: the dtype of the columns
+ engine: the engine of the csv file, default is None
+ converters: specify converter for columns.
+ skiprows: the row number to skip
+ skipfooter: the row number to skip at the end of the file
+ time_col: the columns to parse as dates
+ infer_datetime_format: whether to infer the datetime format
+ lineterminator: the line terminator of the csv file, only valid when engine is 'c'
+ kwargs: other arguments
+ """
+ self.sort_values = sort_values
+ self.encoding = encoding
+ self.sep = sep
+ self.header = header
+ self.names = names
+ self.index_col = index_col
+ self.usecols = usecols
+ self.dtype = dtype
+ self.engine = engine
+ self.converters = converters or {}
+ self.skiprows = skiprows
+ self.skipfooter = skipfooter
+ self.infer_datetime_format = infer_datetime_format
+
+ self.lineterminator = lineterminator
+ self.kwargs = kwargs
+
+ if names is not None:
+ self.header = 0
+
+ if time_col is not None:
+ self.converters[time_col] = lambda x: ciso8601.parse_datetime_as_naive(x)
+
omicron/dal/influx/serialize.py
class NumpyDeserializer(Serializer):
+ def __init__(
+ self,
+ dtype: List[tuple] = "float",
+ sort_values: Union[str, List[str]] = None,
+ use_cols: Union[List[str], List[int]] = None,
+ parse_date: Union[int, str] = "_time",
+ sep: str = ",",
+ encoding: str = "utf-8",
+ skip_rows: Union[int, List[int]] = 1,
+ header_line: int = 1,
+ comments: str = "#",
+ converters: Mapping[int, Callable] = None,
+ ):
+ """construct a deserializer, which will convert a csv like multiline string/bytes array to a numpy array
+
+ the data to be deserialized will be first split into array of fields, then use use_cols to select which fields to use, and re-order them by the order of use_cols. After that, the fields will be converted to numpy array and converted into dtype.
+
+ by default dtype is float, which means the data will be converted to float. If you need to convert to a numpy structured array, then you can specify the dtype as a list of tuples, e.g.
+
+ ```
+ dtype = [('col_1', 'datetime64[s]'), ('col_2', '<U12'), ('col_3', '<U4')]
+
+ ```
+
+        by default, the deserializer will try to convert every line from the very first line. If the first lines contain comments or headers, they should be skipped by the deserializer; set skip_rows to the number of lines to skip.
+
+ for more information, please refer to [numpy.loadtxt](https://numpy.org/doc/stable/reference/generated/numpy.loadtxt.html)
+
+ Args:
+ dtype: dtype of the output numpy array.
+ sort_values: sort the output numpy array by the specified columns. If it's a string, then it's the name of the column, if it's a list of strings, then it's the names of the columns.
+ use_cols: use only the specified columns. If it's a list of strings, then it's the names of the columns (presented in raw data header line), if it's a list of integers, then it's the column index.
+ parse_date: by default we'll convert "_time" column into python datetime.datetime. Set it to None to turn off the conversion. ciso8601 is default parser. If you need to parse date but just don't like ciso8601, then you can turn off default parser (by set parse_date to None), and specify your own parser in converters.
+ sep: separator of each field
+ encoding: if the input is bytes, then encoding is used to decode the bytes to string.
+ skip_rows: required by np.loadtxt, skip the first n lines
+ header_line: which line contains header, started from 1. If you specify use_cols by list of string, then header line must be specified.
+ comments: required by np.loadtxt, skip the lines starting with this string
+            converters: required by np.loadtxt, a dict mapping column name (or column index, when the data has no header) to a converter function.
+
+ """
+ self.dtype = dtype
+ self.use_cols = use_cols
+ self.sep = sep
+ self.encoding = encoding
+ self.skip_rows = skip_rows
+ self.comments = comments
+ self.converters = converters or {}
+ self.sort_values = sort_values
+ self.parse_date = parse_date
+ self.header_line = header_line
+
+ if header_line is None:
+ assert parse_date is None or isinstance(
+ parse_date, int
+ ), "parse_date must be an integer if data contains no header"
+
+ assert use_cols is None or isinstance(
+ use_cols[0], int
+ ), "use_cols must be a list of integers if data contains no header"
+
+ if len(self.converters) > 1:
+ assert all(
+ [isinstance(x, int) for x in self.converters.keys()]
+ ), "converters must be a dict of column index to converter function, if there's no header"
+
+ self._parsed_headers = None
+
+ def _parse_header_once(self, stream):
+ """parse header and convert use_cols, if columns is specified in string. And if parse_date is required, add it into converters
+
+        Args:
+            stream : 包含待解析数据的文本流
+
+        Raises:
+            SerializationError: 当指定的header行为空或无法解析时抛出
+        """
+ if self.header_line is None or self._parsed_headers is not None:
+ return
+
+ try:
+ line = stream.readlines(self.header_line)[-1]
+ cols = line.strip().split(self.sep)
+ self._parsed_headers = cols
+
+ use_cols = self.use_cols
+ if use_cols is not None and isinstance(use_cols[0], str):
+ self.use_cols = [cols.index(col) for col in self.use_cols]
+
+ # convert keys of converters to int
+ converters = {cols.index(k): v for k, v in self.converters.items()}
+
+ self.converters = converters
+
+ if isinstance(self.parse_date, str):
+ parse_date = cols.index(self.parse_date)
+ if parse_date in self.converters.keys():
+ logger.debug(
+ "specify duplicated converter in both parse_date and converters for col %s, use converters.",
+ self.parse_date,
+ )
+ else: # 增加parse_date到converters
+ self.converters[
+ parse_date
+ ] = lambda x: ciso8601.parse_datetime_as_naive(x)
+
+ stream.seek(0)
+ except (IndexError, ValueError):
+ if line.strip() == "":
+ content = "".join(stream.readlines()).strip()
+ if len(content) > 0:
+ raise SerializationError(
+ f"specified heder line {self.header_line} is empty"
+ )
+ else:
+ raise EmptyResult()
+ else:
+ raise SerializationError(f"bad header[{self.header_line}]: {line}")
+
+ def __call__(self, data: bytes) -> np.ndarray:
+ if self.encoding and isinstance(data, bytes):
+ stream = io.StringIO(data.decode(self.encoding))
+ else:
+ stream = io.StringIO(data)
+
+ try:
+ self._parse_header_once(stream)
+ except EmptyResult:
+ return np.empty((0,), dtype=self.dtype)
+
+ arr = np.loadtxt(
+ stream.readlines(),
+ delimiter=self.sep,
+ skiprows=self.skip_rows,
+ dtype=self.dtype,
+ usecols=self.use_cols,
+ converters=self.converters,
+ encoding=self.encoding,
+ )
+
+ # 如果返回仅一条记录,有时会出现 shape == ()
+ if arr.shape == tuple():
+ arr = arr.reshape((-1,))
+ if self.sort_values is not None and arr.size > 1:
+ return np.sort(arr, order=self.sort_values)
+ else:
+ return arr
+
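+一个使用示意(字段名与dtype均为假设;data即query返回的annotated csv):
+
+```python
+des = NumpyDeserializer(
+    dtype=[("frame", "datetime64[s]"), ("code", "<U11"), ("close", "<f4")],
+    use_cols=["_time", "code", "close"],
+    sort_values="frame",
+)
+
+data = b",result,table,_time,code,close\n,_result,0,2019-01-01T00:00:00Z,000001.XSHE,5.15\n"
+bars = des(data)
+```
+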
__init__(self, dtype='float', sort_values=None, use_cols=None, parse_date='_time', sep=',', encoding='utf-8', skip_rows=1, header_line=1, comments='#', converters=None) special
+
+construct a deserializer, which will convert a csv like multiline string/bytes array to a numpy array
+
+the data to be deserialized will be first split into array of fields, then use use_cols to select which fields to use, and re-order them by the order of use_cols. After that, the fields will be converted to numpy array and converted into dtype.
+
+by default dtype is float, which means the data will be converted to float. If you need to convert to a numpy structured array, then you can specify the dtype as a list of tuples, e.g.
+
+```
+dtype = [('col_1', 'datetime64[s]'), ('col_2', '<U12'), ('col_3', '<U4')]
+```
+
+by default, the deserializer will try to convert every line from the very first line. If the first lines contain comments or headers, they should be skipped by the deserializer; set skip_rows to the number of lines to skip.
+
+for more information, please refer to [numpy.loadtxt](https://numpy.org/doc/stable/reference/generated/numpy.loadtxt.html)
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| dtype | List[tuple] | dtype of the output numpy array. | 'float' |
+| sort_values | Union[str, List[str]] | sort the output numpy array by the specified columns. If it's a string, then it's the name of the column, if it's a list of strings, then it's the names of the columns. | None |
+| use_cols | Union[List[str], List[int]] | use only the specified columns. If it's a list of strings, then it's the names of the columns (presented in raw data header line), if it's a list of integers, then it's the column index. | None |
+| parse_date | Union[int, str] | by default we'll convert "_time" column into python datetime.datetime. Set it to None to turn off the conversion. ciso8601 is the default parser. If you need to parse dates but don't like ciso8601, turn off the default parser (by setting parse_date to None) and specify your own parser in converters. | '_time' |
+| sep | str | separator of each field | ',' |
+| encoding | str | if the input is bytes, then encoding is used to decode the bytes to string. | 'utf-8' |
+| skip_rows | Union[int, List[int]] | required by np.loadtxt, skip the first n lines | 1 |
+| header_line | int | which line contains the header, starting from 1. If you specify use_cols by a list of strings, then the header line must be specified. | 1 |
+| comments | str | required by np.loadtxt, skip the lines starting with this string | '#' |
+| converters | Mapping[int, Callable] | required by np.loadtxt, a dict mapping column name (or column index, when the data has no header) to a converter function. | None |
+
omicron/dal/influx/serialize.py
def __init__(
+ self,
+ dtype: List[tuple] = "float",
+ sort_values: Union[str, List[str]] = None,
+ use_cols: Union[List[str], List[int]] = None,
+ parse_date: Union[int, str] = "_time",
+ sep: str = ",",
+ encoding: str = "utf-8",
+ skip_rows: Union[int, List[int]] = 1,
+ header_line: int = 1,
+ comments: str = "#",
+ converters: Mapping[int, Callable] = None,
+):
+ """construct a deserializer, which will convert a csv like multiline string/bytes array to a numpy array
+
+ the data to be deserialized will be first split into array of fields, then use use_cols to select which fields to use, and re-order them by the order of use_cols. After that, the fields will be converted to numpy array and converted into dtype.
+
+ by default dtype is float, which means the data will be converted to float. If you need to convert to a numpy structured array, then you can specify the dtype as a list of tuples, e.g.
+
+ ```
+ dtype = [('col_1', 'datetime64[s]'), ('col_2', '<U12'), ('col_3', '<U4')]
+
+ ```
+
+    by default, the deserializer will try to convert every line from the very first line. If the first lines contain comments or headers, they should be skipped by the deserializer; set skip_rows to the number of lines to skip.
+
+ for more information, please refer to [numpy.loadtxt](https://numpy.org/doc/stable/reference/generated/numpy.loadtxt.html)
+
+ Args:
+ dtype: dtype of the output numpy array.
+ sort_values: sort the output numpy array by the specified columns. If it's a string, then it's the name of the column, if it's a list of strings, then it's the names of the columns.
+ use_cols: use only the specified columns. If it's a list of strings, then it's the names of the columns (presented in raw data header line), if it's a list of integers, then it's the column index.
+ parse_date: by default we'll convert "_time" column into python datetime.datetime. Set it to None to turn off the conversion. ciso8601 is default parser. If you need to parse date but just don't like ciso8601, then you can turn off default parser (by set parse_date to None), and specify your own parser in converters.
+ sep: separator of each field
+ encoding: if the input is bytes, then encoding is used to decode the bytes to string.
+ skip_rows: required by np.loadtxt, skip the first n lines
+ header_line: which line contains header, started from 1. If you specify use_cols by list of string, then header line must be specified.
+ comments: required by np.loadtxt, skip the lines starting with this string
+        converters: required by np.loadtxt, a dict mapping column name (or column index, when the data has no header) to a converter function.
+
+ """
+ self.dtype = dtype
+ self.use_cols = use_cols
+ self.sep = sep
+ self.encoding = encoding
+ self.skip_rows = skip_rows
+ self.comments = comments
+ self.converters = converters or {}
+ self.sort_values = sort_values
+ self.parse_date = parse_date
+ self.header_line = header_line
+
+ if header_line is None:
+ assert parse_date is None or isinstance(
+ parse_date, int
+ ), "parse_date must be an integer if data contains no header"
+
+ assert use_cols is None or isinstance(
+ use_cols[0], int
+ ), "use_cols must be a list of integers if data contains no header"
+
+ if len(self.converters) > 1:
+ assert all(
+ [isinstance(x, int) for x in self.converters.keys()]
+ ), "converters must be a dict of column index to converter function, if there's no header"
+
+ self._parsed_headers = None
+
decimals
+
+math_round(x, digits)
+
+由于浮点数的表示问题,很多语言的round函数与数学上的round函数不一致。下面的函数结果与数学上的一致。
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| x | float | 要进行四舍五入的数字 | required |
+| digits | int | 小数点后保留的位数 | required |
+
omicron/extensions/decimals.py
def math_round(x: float, digits: int):
+ """由于浮点数的表示问题,很多语言的round函数与数学上的round函数不一致。下面的函数结果与数学上的一致。
+
+ Args:
+ x: 要进行四舍五入的数字
+ digits: 小数点后保留的位数
+
+ """
+
+ return int(x * (10**digits) + copysign(0.5, x)) / (10**digits)
+
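+一个对比示例:Python内置round采用银行家舍入,而本函数按数学惯例处理.5:
+
+>>> round(2.5)
+2
+>>> math_round(2.5, 0)
+3.0
+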
price_equal(x, y)
+
+判断股价是否相等
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| x | float | 价格1 | required |
+| y | float | 价格2 | required |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| bool | 如果相等则返回True,否则返回False |
+
omicron/extensions/decimals.py
def price_equal(x: float, y: float) -> bool:
+ """判断股价是否相等
+
+ Args:
+ x : 价格1
+ y : 价格2
+
+ Returns:
+ 如果相等则返回True,否则返回False
+ """
+ return abs(math_round(x, 2) - math_round(y, 2)) < 1e-2
+
np
+
+Extension functions related to numpy
+
+array_math_round(arr, digits)
+
+将一维数组arr的数据进行四舍五入
+
+numpy.around的函数并不是数学上的四舍五入,对1.5和2.5进行round的结果都会变成2,在金融领域计算中,我们必须使用数学意义上的四舍五入。
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| arr | ArrayLike | 输入数组 | required |
+| digits | int | 小数点后保留的位数 | required |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| np.ndarray | 四舍五入后的一维数组 |
+
omicron/extensions/np.py
def array_math_round(arr: Union[float, ArrayLike], digits: int) -> np.ndarray:
+ """将一维数组arr的数据进行四舍五入
+
+ numpy.around的函数并不是数学上的四舍五入,对1.5和2.5进行round的结果都会变成2,在金融领域计算中,我们必须使用数学意义上的四舍五入。
+
+ Args:
+ arr (ArrayLike): 输入数组
+        digits (int): 小数点后保留的位数
+
+ Returns:
+ np.ndarray: 四舍五入后的一维数组
+ """
+ # 如果是单个元素,则直接返回
+ if isinstance(arr, float):
+ return decimals.math_round(arr, digits)
+
+ f = np.vectorize(lambda x: decimals.math_round(x, digits))
+ return f(arr)
+
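+一个对比示例(0.125与0.375可被二进制精确表示,便于观察两种舍入的差异):
+
+>>> np.around([0.125, 0.375], 2)
+array([0.12, 0.38])
+>>> array_math_round([0.125, 0.375], 2)
+array([0.13, 0.38])
+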
array_price_equal(price1, price2)
+
+判断两个价格数组是否相等
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| price1 | ArrayLike | 价格数组 | required |
+| price2 | ArrayLike | 价格数组 | required |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| np.ndarray | 判断结果 |
+
omicron/extensions/np.py
def array_price_equal(price1: ArrayLike, price2: ArrayLike) -> np.ndarray:
+ """判断两个价格数组是否相等
+
+ Args:
+ price1 (ArrayLike): 价格数组
+ price2 (ArrayLike): 价格数组
+
+ Returns:
+ np.ndarray: 判断结果
+ """
+ price1 = array_math_round(price1, 2)
+ price2 = array_math_round(price2, 2)
+
+ return abs(price1 - price2) < 1e-2
+
bars_since(condition, default=None)
+
+Return the number of bars since `condition` sequence was last `True`, or if never, return `default`.
+
+>>> condition = [True, True, False]
+>>> bars_since(condition)
+1
+
omicron/extensions/np.py
def bars_since(condition: Sequence[bool], default=None) -> int:
+ """
+ Return the number of bars since `condition` sequence was last `True`,
+ or if never, return `default`.
+
+ >>> condition = [True, True, False]
+ >>> bars_since(condition)
+ 1
+ """
+ return next(compress(range(len(condition)), reversed(condition)), default)
+
bin_cut(arr, n)
+
+将数组arr切分成n份
+
+todo: use padding + reshape to boost performance
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| arr | list | 待切分的数组 | required |
+| n | int | 切分的份数 | required |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| list | 切分后的子列表组成的列表,空的子列表已被剔除 |
+
omicron/extensions/np.py
def bin_cut(arr: list, n: int):
+ """将数组arr切分成n份
+
+ todo: use padding + reshape to boost performance
+ Args:
+        arr (list): 待切分的数组
+        n (int): 切分的份数
+
+    Returns:
+        list: 切分后的子列表组成的列表,空的子列表已被剔除
+ """
+ result = [[] for i in range(n)]
+
+ for i, e in enumerate(arr):
+ result[i % n].append(e)
+
+ return [e for e in result if len(e)]
+
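+示例:元素按轮转方式分配到各份中,空的子列表会被剔除:
+
+>>> bin_cut([1, 2, 3, 4, 5], 2)
+[[1, 3, 5], [2, 4]]
+>>> bin_cut([1, 2], 3)
+[[1], [2]]
+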
count_between(arr, start, end)
+
+计算数组中,`start`元素与`end`元素之间共有多少个元素
+
+要求arr必须是已排序。计算结果会包含区间边界点。
+ +Examples:
+>>> arr = [20050104, 20050105, 20050106, 20050107, 20050110, 20050111]
+>>> count_between(arr, 20050104, 20050111)
+6
+
>>> count_between(arr, 20050104, 20050109)
+4
+
omicron/extensions/np.py
def count_between(arr, start, end):
+ """计算数组中,`start`元素与`end`元素之间共有多少个元素
+
+ 要求arr必须是已排序。计算结果会包含区间边界点。
+
+ Examples:
+ >>> arr = [20050104, 20050105, 20050106, 20050107, 20050110, 20050111]
+ >>> count_between(arr, 20050104, 20050111)
+ 6
+
+ >>> count_between(arr, 20050104, 20050109)
+ 4
+ """
+ pos_start = np.searchsorted(arr, start, side="right")
+ pos_end = np.searchsorted(arr, end, side="right")
+
+ counter = pos_end - pos_start + 1
+ if start < arr[0]:
+ counter -= 1
+ if end > arr[-1]:
+ counter -= 1
+
+ return counter
+
dataframe_to_structured_array(df, dtypes=None)
+
+convert dataframe (with all columns, and index possibly) to numpy structured arrays
+
+`len(dtypes)` should be either equal to `len(df.columns)` or `len(df.columns) + 1`. In the latter case, it implies to include `df.index` into the converted array.
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| df | DataFrame | the one needs to be converted | required |
+| dtypes | List[Tuple] | Defaults to None. If it's `None`, then dtypes of `df` is used, in such case, the `index` of `df` will not be converted. | None |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| ArrayLike | the converted structured array |
+
omicron/extensions/np.py
def dataframe_to_structured_array(
+ df: DataFrame, dtypes: List[Tuple] = None
+) -> ArrayLike:
+ """convert dataframe (with all columns, and index possibly) to numpy structured arrays
+
+    `len(dtypes)` should be either equal to `len(df.columns)` or `len(df.columns) + 1`. In the latter case, it implies to include `df.index` into the converted array.
+
+ Args:
+ df: the one needs to be converted
+ dtypes: Defaults to None. If it's `None`, then dtypes of `df` is used, in such case, the `index` of `df` will not be converted.
+
+ Returns:
+        ArrayLike: the converted structured array
+ """
+ v = df
+ if dtypes is not None:
+ dtypes_in_dict = {key: value for key, value in dtypes}
+
+ col_len = len(df.columns)
+ if len(dtypes) == col_len + 1:
+ v = df.reset_index()
+
+ rename_index_to = set(dtypes_in_dict.keys()).difference(set(df.columns))
+ v.rename(columns={"index": list(rename_index_to)[0]}, inplace=True)
+ elif col_len != len(dtypes):
+ raise ValueError(
+ f"length of dtypes should be either {col_len} or {col_len + 1}, is {len(dtypes)}"
+ )
+
+ # re-arrange order of dtypes, in order to align with df.columns
+ dtypes = []
+ for name in v.columns:
+ dtypes.append((name, dtypes_in_dict[name]))
+ else:
+ dtypes = df.dtypes
+
+ return np.array(np.rec.fromrecords(v.values), dtype=dtypes)
+
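+一个使用示意(列名与dtype均为假设;当len(dtypes)比列数多1时,df.index会被并入转换结果):
+
+>>> import pandas as pd
+>>> df = pd.DataFrame({"close": [5.15, 5.2]}, index=[20230103, 20230104])
+>>> arr = dataframe_to_structured_array(df, [("frame", "<i8"), ("close", "<f8")])
+>>> arr["frame"]
+array([20230103, 20230104])
+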
dict_to_numpy_array(d, dtype)
+
+convert dictionary to numpy array
+
+Examples:
+
+>>> d = {"aaron": 5, "jack": 6}
+>>> dtype = [("name", "S8"), ("score", "<i4")]
+>>> dict_to_numpy_array(d, dtype)
+array([(b'aaron', 5), (b'jack', 6)],
+      dtype=[('name', 'S8'), ('score', '<i4')])
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| d | dict | the dictionary to convert | required |
+| dtype | List[Tuple] | dtype of the target structured array | required |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| np.array | the converted structured array |
+
omicron/extensions/np.py
def dict_to_numpy_array(d: dict, dtype: List[Tuple]) -> np.array:
+ """convert dictionary to numpy array
+
+ Examples:
+
+ >>> d = {"aaron": 5, "jack": 6}
+ >>> dtype = [("name", "S8"), ("score", "<i4")]
+ >>> dict_to_numpy_array(d, dtype)
+ array([(b'aaron', 5), (b'jack', 6)],
+ dtype=[('name', 'S8'), ('score', '<i4')])
+
+ Args:
+        d (dict): the dictionary to convert
+        dtype (List[Tuple]): dtype of the target structured array
+
+    Returns:
+        np.array: the converted structured array
+ """
+ return np.fromiter(d.items(), dtype=dtype, count=len(d))
+
fill_nan(ts)
+
+将ts中的NaN替换为其前值
+
+如果ts起头的元素为NaN,则用第一个非NaN元素替换。
+
+如果所有元素都为NaN,则无法替换。
+ +Examples:
+>>> arr = np.arange(6, dtype=np.float32)
+>>> arr[3:5] = np.NaN
+>>> fill_nan(arr)
+...
+array([0., 1., 2., 2., 2., 5.], dtype=float32)
+
>>> arr = np.arange(6, dtype=np.float32)
+>>> arr[0:2] = np.nan
+>>> fill_nan(arr)
+...
+array([2., 2., 2., 3., 4., 5.], dtype=float32)
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| ts | np.array | 待填充的一维数组 | required |
+
omicron/extensions/np.py
def fill_nan(ts: np.ndarray):
+ """将ts中的NaN替换为其前值
+
+ 如果ts起头的元素为NaN,则用第一个非NaN元素替换。
+
+ 如果所有元素都为NaN,则无法替换。
+
+ Example:
+ >>> arr = np.arange(6, dtype=np.float32)
+ >>> arr[3:5] = np.NaN
+ >>> fill_nan(arr)
+ ... # doctest: +NORMALIZE_WHITESPACE
+ array([0., 1., 2., 2., 2., 5.], dtype=float32)
+
+ >>> arr = np.arange(6, dtype=np.float32)
+ >>> arr[0:2] = np.nan
+ >>> fill_nan(arr)
+ ... # doctest: +NORMALIZE_WHITESPACE
+ array([2., 2., 2., 3., 4., 5.], dtype=float32)
+
+ Args:
+        ts (np.array): 待填充的一维数组
+ """
+ if np.all(np.isnan(ts)):
+ raise ValueError("all of ts are NaN")
+
+ if ts[0] is None or math.isnan(ts[0]):
+ idx = np.argwhere(~np.isnan(ts))[0]
+ ts[0] = ts[idx]
+
+ mask = np.isnan(ts)
+ idx = np.where(~mask, np.arange(mask.size), 0)
+ np.maximum.accumulate(idx, out=idx)
+ return ts[idx]
+
find_runs(x)
+
+Find runs of consecutive items in an array.
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| x | ArrayLike | the sequence to find runs in | required |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| Tuple[np.ndarray, np.ndarray, np.ndarray] | A tuple of unique values, start indices, and length of runs |
+
omicron/extensions/np.py
def find_runs(x: ArrayLike) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Find runs of consecutive items in an array.
+
+ Args:
+ x: the sequence to find runs in
+
+ Returns:
+ A tuple of unique values, start indices, and length of runs
+ """
+
+ # ensure array
+ x = np.asanyarray(x)
+ if x.ndim != 1:
+ raise ValueError("only 1D array supported")
+ n = x.shape[0]
+
+ # handle empty array
+ if n == 0:
+ return np.array([]), np.array([]), np.array([])
+
+ else:
+ # find run starts
+ loc_run_start = np.empty(n, dtype=bool)
+ loc_run_start[0] = True
+ np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
+ run_starts = np.nonzero(loc_run_start)[0]
+
+ # find run values
+ run_values = x[loc_run_start]
+
+ # find run lengths
+ run_lengths = np.diff(np.append(run_starts, n))
+
+ return run_values, run_starts, run_lengths
+
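+示例:
+
+>>> find_runs([1, 1, 2, 2, 2, 3])
+(array([1, 2, 3]), array([0, 2, 5]), array([2, 3, 1]))
+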
floor(arr, item)
+
+在数据arr中,找到小于等于item的那一个值。如果item小于所有arr元素的值,返回arr[0];如果item大于所有arr元素的值,返回arr[-1]
+
+与`minute_frames_floor`不同的是,本函数不做回绕与进位.
Examples:
+>>> a = [3, 6, 9]
+>>> floor(a, -1)
+3
+>>> floor(a, 9)
+9
+>>> floor(a, 10)
+9
+>>> floor(a, 4)
+3
+>>> floor(a,10)
+9
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| arr | | 已排序的数组 | required |
+| item | | 待查找的值 | required |
+
omicron/extensions/np.py
def floor(arr, item):
+ """
+ 在数据arr中,找到小于等于item的那一个值。如果item小于所有arr元素的值,返回arr[0];如果item
+ 大于所有arr元素的值,返回arr[-1]
+
+ 与`minute_frames_floor`不同的是,本函数不做回绕与进位.
+
+ Examples:
+ >>> a = [3, 6, 9]
+ >>> floor(a, -1)
+ 3
+ >>> floor(a, 9)
+ 9
+ >>> floor(a, 10)
+ 9
+ >>> floor(a, 4)
+ 3
+ >>> floor(a,10)
+ 9
+
+    Args:
+        arr: 已排序的数组
+        item: 待查找的值
+
+    Returns:
+        arr中小于等于item的那个元素
+ """
+ if item < arr[0]:
+ return arr[0]
+ index = np.searchsorted(arr, item, side="right")
+ return arr[index - 1]
+
join_by_left(key, r1, r2, mask=True)
+
+左连接 `r1`, `r2` by `key`
+
+如果`r1`中存在`r2`中没有的行,则该行对应的`r2`中的那些字段将被mask,或者填充随机数。
+
+same as numpy.lib.recfunctions.join_by(key, r1, r2, jointype='leftouter'), but allows r1 to have duplicate keys
+
+Reference: [stackoverflow](https://stackoverflow.com/a/53261882/13395693)
Examples:
+>>> # to join the following
+>>> # [[ 1, 2],
+>>> # [ 1, 3], x [[1, 5],
+>>> # [ 2, 3]] [4, 7]]
+>>> # only first two rows in left will be joined
+
>>> r1 = np.array([(1, 2), (1,3), (2,3)], dtype=[('seq', 'i4'), ('score', 'i4')])
+>>> r2 = np.array([(1, 5), (4,7)], dtype=[('seq', 'i4'), ('age', 'i4')])
+>>> joined = join_by_left('seq', r1, r2)
+>>> print(joined)
+[(1, 2, 5) (1, 3, 5) (2, 3, --)]
+
>>> print(joined.dtype)
+(numpy.record, [('seq', '<i4'), ('score', '<i4'), ('age', '<i4')])
+
>>> joined[2][2]
+masked
+
>>> joined.tolist()[2][2] == None
+True
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| key | | join关键字 | required |
+| r1 | | 数据集1 | required |
+| r2 | | 数据集2 | required |
+| mask | bool | 未匹配行的字段是否以掩码(masked)形式返回 | True |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| | a numpy array |
+
omicron/extensions/np.py
def join_by_left(key, r1, r2, mask=True):
+ """左连接 `r1`, `r2` by `key`
+
+ 如果`r1`中存在`r2`中没有的行,则该行对应的`r2`中的那些字段将被mask,或者填充随机数。
+    same as numpy.lib.recfunctions.join_by(key, r1, r2, jointype='leftouter'), but allows r1 to have duplicate keys
+
+ [Reference: stackoverflow](https://stackoverflow.com/a/53261882/13395693)
+
+ Examples:
+ >>> # to join the following
+ >>> # [[ 1, 2],
+ >>> # [ 1, 3], x [[1, 5],
+ >>> # [ 2, 3]] [4, 7]]
+ >>> # only first two rows in left will be joined
+
+ >>> r1 = np.array([(1, 2), (1,3), (2,3)], dtype=[('seq', 'i4'), ('score', 'i4')])
+ >>> r2 = np.array([(1, 5), (4,7)], dtype=[('seq', 'i4'), ('age', 'i4')])
+ >>> joined = join_by_left('seq', r1, r2)
+ >>> print(joined)
+ [(1, 2, 5) (1, 3, 5) (2, 3, --)]
+
+ >>> print(joined.dtype)
+ (numpy.record, [('seq', '<i4'), ('score', '<i4'), ('age', '<i4')])
+
+ >>> joined[2][2]
+ masked
+
+ >>> joined.tolist()[2][2] == None
+ True
+
+ Args:
+        key : join关键字
+        r1 : 数据集1
+        r2 : 数据集2
+        mask : 未匹配行的字段是否以掩码(masked)形式返回
+
+ Returns:
+ a numpy array
+ """
+ # figure out the dtype of the result array
+ descr1 = r1.dtype.descr
+ descr2 = [d for d in r2.dtype.descr if d[0] not in r1.dtype.names]
+ descrm = descr1 + descr2
+
+ # figure out the fields we'll need from each array
+ f1 = [d[0] for d in descr1]
+ f2 = [d[0] for d in descr2]
+
+ # cache the number of columns in f1
+ ncol1 = len(f1)
+
+ # get a dict of the rows of r2 grouped by key
+ rows2 = {}
+ for row2 in r2:
+ rows2.setdefault(row2[key], []).append(row2)
+
+ # figure out how many rows will be in the result
+ nrowm = 0
+ for k1 in r1[key]:
+ if k1 in rows2:
+ nrowm += len(rows2[k1])
+ else:
+ nrowm += 1
+
+ # allocate the return array
+ # ret = np.full((nrowm, ), fill, dtype=descrm)
+ _ret = np.recarray(nrowm, dtype=descrm)
+ if mask:
+ ret = np.ma.array(_ret, mask=True)
+ else:
+ ret = _ret
+
+ # merge the data into the return array
+ i = 0
+ for row1 in r1:
+ if row1[key] in rows2:
+ for row2 in rows2[row1[key]]:
+ ret[i] = tuple(row1[f1]) + tuple(row2[f2])
+ i += 1
+ else:
+ for j in range(ncol1):
+ ret[i][j] = row1[j]
+ i += 1
+
+ return ret
+
numpy_append_fields(base, names, data, dtypes)
+
+给现有的数组`base`增加新的字段
+
+实现了`numpy.lib.recfunctions.rec_append_fields`的功能。提供这个功能,是因为`rec_append_fields`不能处理`data`元素的类型为Object的情况。
+
+新增的数据列将顺序排列在其它列的右边。
+
+>>> # 新增单个字段
+>>> import numpy
+>>> old = np.array([i for i in range(3)], dtype=[('col1', '<f4')])
+>>> new_list = [2 * i for i in range(3)]
+>>> res = numpy_append_fields(old, 'new_col', new_list, [('new_col', '<f4')])
+>>> print(res)
+...
+[(0., 0.) (1., 2.) (2., 4.)]
+
>>> # 新增多个字段
+>>> data = [res['col1'].tolist(), res['new_col'].tolist()]
+>>> print(numpy_append_fields(old, ('col3', 'col4'), data, [('col3', '<f4'), ('col4', '<f4')]))
+...
+[(0., 0., 0.) (1., 1., 2.) (2., 2., 4.)]
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| base | numpy.ndarray | 基础数组 | required |
+| names | Union[str, List[str]] | 新增字段的名字,可以是字符串(单字段的情况),也可以是字符串列表 | required |
+| data | list | 增加的字段的数据,list类型 | required |
+| dtypes | List | 新增字段的dtype | required |
+
omicron/extensions/np.py
def numpy_append_fields(
+ base: np.ndarray, names: Union[str, List[str]], data: List, dtypes: List
+) -> np.ndarray:
+ """给现有的数组`base`增加新的字段
+
+ 实现了`numpy.lib.recfunctions.rec_append_fields`的功能。提供这个功能,是因为`rec_append_fields`不能处理`data`元素的类型为Object的情况。
+
+ 新增的数据列将顺序排列在其它列的右边。
+
+ Example:
+ >>> # 新增单个字段
+ >>> import numpy
+ >>> old = np.array([i for i in range(3)], dtype=[('col1', '<f4')])
+ >>> new_list = [2 * i for i in range(3)]
+ >>> res = numpy_append_fields(old, 'new_col', new_list, [('new_col', '<f4')])
+ >>> print(res)
+ ... # doctest: +NORMALIZE_WHITESPACE
+ [(0., 0.) (1., 2.) (2., 4.)]
+
+ >>> # 新增多个字段
+ >>> data = [res['col1'].tolist(), res['new_col'].tolist()]
+ >>> print(numpy_append_fields(old, ('col3', 'col4'), data, [('col3', '<f4'), ('col4', '<f4')]))
+ ... # doctest: +NORMALIZE_WHITESPACE
+ [(0., 0., 0.) (1., 1., 2.) (2., 2., 4.)]
+
+ Args:
+ base ([numpy.array]): 基础数组
+ names ([type]): 新增字段的名字,可以是字符串(单字段的情况),也可以是字符串列表
+ data (list): 增加的字段的数据,list类型
+ dtypes ([type]): 新增字段的dtype
+ """
+ if isinstance(names, str):
+ names = [names]
+ data = [data]
+
+ result = np.empty(base.shape, dtype=base.dtype.descr + dtypes)
+ for col in base.dtype.names:
+ result[col] = base[col]
+
+ for i in range(len(names)):
+ result[names[i]] = data[i]
+
+ return result
+
remove_nan(ts)
+
+
+¶从ts
中去除NaN
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
ts |
+ np.array |
+ [description] |
+ required | +
Returns:
+Type | +Description | +
---|---|
np.array |
+ [description] |
+
omicron/extensions/np.py
def remove_nan(ts: np.ndarray) -> np.ndarray:
+ """从`ts`中去除NaN
+
+ Args:
+ ts (np.array): [description]
+
+ Returns:
+ np.array: [description]
+ """
+ return ts[~np.isnan(ts.astype(float))]
+
replace_zero(ts, replacement=None)
+
+
+¶将ts中的0替换为前值, 处理volume数据时常用用到
+如果提供了replacement, 则替换为replacement
+ +omicron/extensions/np.py
def replace_zero(ts: np.ndarray, replacement=None) -> np.ndarray:
+ """将ts中的0替换为前值, 处理volume数据时常用用到
+
+ 如果提供了replacement, 则替换为replacement
+
+ """
+ if replacement is not None:
+ return np.where(ts == 0, replacement, ts)
+
+ if np.all(ts == 0):
+ raise ValueError("all of ts are 0")
+
+ if ts[0] == 0:
+ idx = np.argwhere(ts != 0)[0]
+ ts[0] = ts[idx]
+
+ mask = ts == 0
+ idx = np.where(~mask, np.arange(mask.size), 0)
+ np.maximum.accumulate(idx, out=idx)
+ return ts[idx]
+
rolling(x, win, func)
+
+
+¶对序列x
进行窗口滑动计算。
如果func
要实现的功能是argmax, argmin, max, mean, median, min, rank, std, sum, var等,move_argmax,请使用bottleneck中的move_argmin, move_max, move_mean, move_median, move_min move_rank, move_std, move_sum, move_var。这些函数的性能更好。
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
x |
+ [type] |
+ [description] |
+ required | +
win |
+ [type] |
+ [description] |
+ required | +
func |
+ [type] |
+ [description] |
+ required | +
Returns:
+Type | +Description | +
---|---|
[type] |
+ [description] |
+
omicron/extensions/np.py
def rolling(x, win, func):
+ """对序列`x`进行窗口滑动计算。
+
+ 如果`func`要实现的功能是argmax, argmin, max, mean, median, min, rank, std, sum, var等,move_argmax,请使用bottleneck中的move_argmin, move_max, move_mean, move_median, move_min move_rank, move_std, move_sum, move_var。这些函数的性能更好。
+
+ Args:
+ x ([type]): [description]
+ win ([type]): [description]
+ func ([type]): [description]
+
+ Returns:
+ [type]: [description]
+ """
+ results = []
+ for subarray in sliding_window_view(x, window_shape=win):
+ results.append(func(subarray))
+
+ return np.array(results)
+
shift(arr, start, offset)
+
+
+¶在numpy数组arr中,找到start(或者最接近的一个),取offset对应的元素。
+要求arr
已排序。offset
为正,表明向后移位;offset
为负,表明向前移位
Examples:
+>>> arr = [20050104, 20050105, 20050106, 20050107, 20050110, 20050111]
+>>> shift(arr, 20050104, 1)
+20050105
+
>>> shift(arr, 20050105, -1)
+20050104
+
>>> # 起始点已右越界,且向右shift,返回起始点
+>>> shift(arr, 20050120, 1)
+20050120
+
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
arr |
+ + | 已排序的数组 |
+ required | +
start |
+ + | numpy可接受的数据类型 |
+ required | +
offset |
+ int |
+ [description] |
+ required | +
Returns:
+Type | +Description | +
---|---|
+ | 移位后得到的元素值 |
+
omicron/extensions/np.py
def shift(arr, start, offset):
+ """在numpy数组arr中,找到start(或者最接近的一个),取offset对应的元素。
+
+ 要求`arr`已排序。`offset`为正,表明向后移位;`offset`为负,表明向前移位
+
+ Examples:
+ >>> arr = [20050104, 20050105, 20050106, 20050107, 20050110, 20050111]
+ >>> shift(arr, 20050104, 1)
+ 20050105
+
+ >>> shift(arr, 20050105, -1)
+ 20050104
+
+ >>> # 起始点已右越界,且向右shift,返回起始点
+ >>> shift(arr, 20050120, 1)
+ 20050120
+
+
+ Args:
+ arr : 已排序的数组
+ start : numpy可接受的数据类型
+ offset (int): [description]
+
+ Returns:
+ 移位后得到的元素值
+ """
+ pos = np.searchsorted(arr, start, side="right")
+
+ if pos + offset - 1 >= len(arr):
+ return start
+ else:
+ return arr[pos + offset - 1]
+
smallest_n_argpos(ts, n)
+
+
+¶get smallest n (min->max) elements and return argpos which its value ordered in ascent
+ +Examples:
+>>> smallest_n_argpos([np.nan, 4, 3, 9, 8, 5, 2, 1, 0, 6, 7], 2)
+array([8, 7])
+
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
ts |
+ np.array |
+ 输入的数组 |
+ required | +
n |
+ int |
+ 取最小的n个元素 |
+ required | +
Returns:
+Type | +Description | +
---|---|
np.array |
+ [description] |
+
omicron/extensions/np.py
def smallest_n_argpos(ts: np.array, n: int) -> np.array:
+ """get smallest n (min->max) elements and return argpos which its value ordered in ascent
+
+ Example:
+ >>> smallest_n_argpos([np.nan, 4, 3, 9, 8, 5, 2, 1, 0, 6, 7], 2)
+ array([8, 7])
+
+ Args:
+ ts (np.array): 输入的数组
+ n (int): 取最小的n个元素
+
+ Returns:
+ np.array: [description]
+ """
+ return np.argsort(ts)[:n]
+
to_pydatetime(tm)
+
+
+¶将numpy.datetime64对象转换成为python的datetime对象
+numpy.ndarray.item()方法可用以将任何numpy对象转换成python对象,推荐在任何适用的地方使用.item()方法,而不是本方法。示例: +
1 +2 +3 +4 |
|
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
tm |
+ + | the input numpy datetime object |
+ required | +
Returns:
+Type | +Description | +
---|---|
datetime.datetime |
+ python datetime object |
+
.. deprecated:: 2.0.0 use tm.item()
instead
omicron/extensions/np.py
@deprecated("2.0.0", details="use `tm.item()` instead")
+def to_pydatetime(tm: np.datetime64) -> datetime.datetime:
+ """将numpy.datetime64对象转换成为python的datetime对象
+
+ numpy.ndarray.item()方法可用以将任何numpy对象转换成python对象,推荐在任何适用的地方使用.item()方法,而不是本方法。示例:
+ ```
+ arr = np.array(['2022-09-08', '2022-09-09'], dtype='datetime64[s]')
+ arr.item(0) # output is datetime.datetime(2022, 9, 8, 0, 0)
+
+ arr[1].item() # output is datetime.datetime(2022, 9, 9, 0, 0)
+ ```
+
+ Args:
+ tm : the input numpy datetime object
+
+ Returns:
+ python datetime object
+ """
+ unix_epoch = np.datetime64(0, "s")
+ one_second = np.timedelta64(1, "s")
+ seconds_since_epoch = (tm - unix_epoch) / one_second
+
+ return datetime.datetime.utcfromtimestamp(seconds_since_epoch)
+
top_n_argpos(ts, n)
+
+
+¶get top n (max->min) elements and return argpos which its value ordered in descent
+ +Examples:
+>>> top_n_argpos([np.nan, 4, 3, 9, 8, 5, 2, 1, 0, 6, 7], 2)
+array([3, 4])
+
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
ts |
+ np.array |
+ [description] |
+ required | +
n |
+ int |
+ [description] |
+ required | +
Returns:
+Type | +Description | +
---|---|
np.array |
+ [description] |
+
omicron/extensions/np.py
def top_n_argpos(ts: np.array, n: int) -> np.array:
+ """get top n (max->min) elements and return argpos which its value ordered in descent
+
+ Example:
+ >>> top_n_argpos([np.nan, 4, 3, 9, 8, 5, 2, 1, 0, 6, 7], 2)
+ array([3, 4])
+
+ Args:
+ ts (np.array): [description]
+ n (int): [description]
+
+ Returns:
+ np.array: [description]
+ """
+ ts_ = np.copy(ts)
+ ts_[np.isnan(ts_)] = -np.inf
+ return np.argsort(ts_)[-n:][::-1]
+
Omicron provides data persistence, time utilities (trading calendar, triggers), market-data models, basic numerical routines, and basic quantitative factors.
close() (async)

Close the connection to the cache.

omicron/__init__.py
async def close():
+ """关闭与缓存的连接"""
+
+ try:
+ await cache.close()
+ except Exception as e: # noqa
+ pass
+
init(app_cache=5) (async)

Initialize Omicron.

Sets up the InfluxDB and cache connections, and loads the trading calendar and the security list.

The connections opened here should be closed by calling `close()` when the program exits.

omicron/__init__.py
async def init(app_cache: int = 5):
+ """初始化Omicron
+
+ 初始化influxDB, 缓存等连接, 并加载日历和证券列表
+
+ 上述初始化的连接,应该在程序退出时,通过调用`close()`关闭
+ """
+ global cache
+
+ await cache.init(app=app_cache)
+ await tf.init()
+
+ from omicron.models.security import Security
+
+ await Security.init()
+
decimals

math_round(x, digits)

Because of floating-point representation issues, the built-in round function of many languages does not agree with mathematical rounding. This function's result agrees with the mathematical definition.
Parameters:

Name | Type | Description | Default
---|---|---|---
x | float | the number to round | required
digits | int | number of decimal places to keep | required
omicron/extensions/decimals.py
def math_round(x: float, digits: int):
+ """由于浮点数的表示问题,很多语言的round函数与数学上的round函数不一致。下面的函数结果与数学上的一致。
+
+ Args:
+ x: 要进行四舍五入的数字
+ digits: 小数点后保留的位数
+
+ """
+
+ return int(x * (10**digits) + copysign(0.5, x)) / (10**digits)
+
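A quick check of the behavior; a minimal sketch, with the import path taken from the file header above:

```
>>> from omicron.extensions.decimals import math_round
>>> round(0.5)           # Python's built-in round uses banker's rounding
0
>>> math_round(0.5, 0)   # mathematical (half-up) rounding
1.0
>>> math_round(-0.5, 0)  # symmetric for negative numbers
-1.0
>>> math_round(1.2345, 2)
1.23
```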
price_equal(x, y)

Check whether two stock prices are equal (after rounding both to 2 decimal places).

Parameters:

Name | Type | Description | Default
---|---|---|---
x | float | price 1 | required
y | float | price 2 | required

Returns:

Type | Description
---|---
bool | True if the prices are equal, otherwise False
omicron/extensions/decimals.py
def price_equal(x: float, y: float) -> bool:
+ """判断股价是否相等
+
+ Args:
+ x : 价格1
+ y : 价格2
+
+ Returns:
+ 如果相等则返回True,否则返回False
+ """
+ return abs(math_round(x, 2) - math_round(y, 2)) < 1e-2
+
np

Extension functions related to numpy.

array_math_round(arr, digits)

Round the elements of the 1-D array `arr` mathematically (half-up).

numpy.around does not perform mathematical rounding: rounding both 1.5 and 2.5 yields 2. In financial calculations we must use rounding in the mathematical sense.
Parameters:

Name | Type | Description | Default
---|---|---|---
arr | ArrayLike | the input array | required
digits | int | number of decimal places to keep | required

Returns:

Type | Description
---|---
np.ndarray | the rounded 1-D array
omicron/extensions/np.py
def array_math_round(arr: Union[float, ArrayLike], digits: int) -> np.ndarray:
+ """将一维数组arr的数据进行四舍五入
+
+ numpy.around的函数并不是数学上的四舍五入,对1.5和2.5进行round的结果都会变成2,在金融领域计算中,我们必须使用数学意义上的四舍五入。
+
+ Args:
+ arr (ArrayLike): 输入数组
+ digits (int):
+
+ Returns:
+ np.ndarray: 四舍五入后的一维数组
+ """
+ # 如果是单个元素,则直接返回
+ if isinstance(arr, float):
+ return decimals.math_round(arr, digits)
+
+ f = np.vectorize(lambda x: decimals.math_round(x, digits))
+ return f(arr)
+
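The difference from numpy.around, as a minimal sketch (import path taken from the file header above):

```
>>> import numpy as np
>>> from omicron.extensions.np import array_math_round
>>> np.around([0.5, 1.5, 2.5], 0)         # round-half-to-even
array([0., 2., 2.])
>>> array_math_round([0.5, 1.5, 2.5], 0)  # mathematical rounding
array([1., 2., 3.])
```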
array_price_equal(price1, price2)

Check element-wise whether two price arrays are equal (after rounding both to 2 decimal places).

Parameters:

Name | Type | Description | Default
---|---|---|---
price1 | ArrayLike | price array | required
price2 | ArrayLike | price array | required

Returns:

Type | Description
---|---
np.ndarray | boolean array of comparison results
omicron/extensions/np.py
def array_price_equal(price1: ArrayLike, price2: ArrayLike) -> np.ndarray:
+ """判断两个价格数组是否相等
+
+ Args:
+ price1 (ArrayLike): 价格数组
+ price2 (ArrayLike): 价格数组
+
+ Returns:
+ np.ndarray: 判断结果
+ """
+ price1 = array_math_round(price1, 2)
+ price2 = array_math_round(price2, 2)
+
+ return abs(price1 - price2) < 1e-2
+
bars_since(condition, default=None)

Return the number of bars since the `condition` sequence was last `True`, or, if never, return `default`.

```
>>> condition = [True, True, False]
>>> bars_since(condition)
1
```
omicron/extensions/np.py
def bars_since(condition: Sequence[bool], default=None) -> int:
+ """
+ Return the number of bars since `condition` sequence was last `True`,
+ or if never, return `default`.
+
+ >>> condition = [True, True, False]
+ >>> bars_since(condition)
+ 1
+ """
+ return next(compress(range(len(condition)), reversed(condition)), default)
+
bin_cut(arr, n)

Split the array `arr` into `n` bins, round-robin.

todo: use padding + reshape to boost performance

Parameters:

Name | Type | Description | Default
---|---|---|---
arr | list | the array to split | required
n | int | number of bins | required

Returns:

Type | Description
---|---
list | the bins, with empty bins dropped
omicron/extensions/np.py
def bin_cut(arr: list, n: int):
+ """将数组arr切分成n份
+
+ todo: use padding + reshape to boost performance
+ Args:
+ arr ([type]): [description]
+ n ([type]): [description]
+
+ Returns:
+ [type]: [description]
+ """
+ result = [[] for i in range(n)]
+
+ for i, e in enumerate(arr):
+ result[i % n].append(e)
+
+ return [e for e in result if len(e)]
+
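A minimal sketch of the round-robin behavior:

```
>>> from omicron.extensions.np import bin_cut
>>> bin_cut([1, 2, 3, 4, 5], 2)   # elements are dealt round-robin
[[1, 3, 5], [2, 4]]
>>> bin_cut([1, 2], 3)            # empty bins are dropped
[[1], [2]]
```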
count_between(arr, start, end)

Count how many elements lie between the `start` element and the `end` element in the array.

`arr` must be sorted. The result includes the interval's boundary points.

Examples:

>>> arr = [20050104, 20050105, 20050106, 20050107, 20050110, 20050111]
>>> count_between(arr, 20050104, 20050111)
6

>>> count_between(arr, 20050104, 20050109)
4
omicron/extensions/np.py
def count_between(arr, start, end):
+ """计算数组中,`start`元素与`end`元素之间共有多少个元素
+
+ 要求arr必须是已排序。计算结果会包含区间边界点。
+
+ Examples:
+ >>> arr = [20050104, 20050105, 20050106, 20050107, 20050110, 20050111]
+ >>> count_between(arr, 20050104, 20050111)
+ 6
+
+ >>> count_between(arr, 20050104, 20050109)
+ 4
+ """
+ pos_start = np.searchsorted(arr, start, side="right")
+ pos_end = np.searchsorted(arr, end, side="right")
+
+ counter = pos_end - pos_start + 1
+ if start < arr[0]:
+ counter -= 1
+ if end > arr[-1]:
+ counter -= 1
+
+ return counter
+
dataframe_to_structured_array(df, dtypes=None)

Convert a dataframe (all columns, and possibly the index) to a numpy structured array.

`len(dtypes)` should equal either `len(df.columns)` or `len(df.columns) + 1`. In the latter case, `df.index` is included in the converted array.

Parameters:

Name | Type | Description | Default
---|---|---|---
df | DataFrame | the dataframe to be converted | required
dtypes | List[Tuple] | if `None`, the dtypes of `df` are used, and the index of `df` is not converted | None

Returns:

Type | Description
---|---
ArrayLike | the converted structured array
omicron/extensions/np.py
def dataframe_to_structured_array(
+ df: DataFrame, dtypes: List[Tuple] = None
+) -> ArrayLike:
+ """convert dataframe (with all columns, and index possibly) to numpy structured arrays
+
+ `len(dtypes)` should be either equal to `len(df.columns)` or `len(df.columns) + 1`. In the later case, it implies to include `df.index` into converted array.
+
+ Args:
+ df: the one needs to be converted
+ dtypes: Defaults to None. If it's `None`, then dtypes of `df` is used, in such case, the `index` of `df` will not be converted.
+
+ Returns:
+ ArrayLike: [description]
+ """
+ v = df
+ if dtypes is not None:
+ dtypes_in_dict = {key: value for key, value in dtypes}
+
+ col_len = len(df.columns)
+ if len(dtypes) == col_len + 1:
+ v = df.reset_index()
+
+ rename_index_to = set(dtypes_in_dict.keys()).difference(set(df.columns))
+ v.rename(columns={"index": list(rename_index_to)[0]}, inplace=True)
+ elif col_len != len(dtypes):
+ raise ValueError(
+ f"length of dtypes should be either {col_len} or {col_len + 1}, is {len(dtypes)}"
+ )
+
+ # re-arrange order of dtypes, in order to align with df.columns
+ dtypes = []
+ for name in v.columns:
+ dtypes.append((name, dtypes_in_dict[name]))
+ else:
+ dtypes = df.dtypes
+
+ return np.array(np.rec.fromrecords(v.values), dtype=dtypes)
+
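A minimal sketch of the index-inclusion case (output formatting approximate):

```
>>> import pandas as pd
>>> from omicron.extensions.np import dataframe_to_structured_array
>>> df = pd.DataFrame({"open": [9.9, 10.1], "close": [10.0, 10.2]})
>>> # len(dtypes) == len(df.columns) + 1, so the index is included as "frame"
>>> dataframe_to_structured_array(df, [("frame", "i8"), ("open", "f4"), ("close", "f4")])
array([(0,  9.9, 10. ), (1, 10.1, 10.2)],
      dtype=[('frame', '<i8'), ('open', '<f4'), ('close', '<f4')])
```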
dict_to_numpy_array(d, dtype)

Convert a dictionary to a numpy structured array.

Examples:

>>> d = {"aaron": 5, "jack": 6}
>>> dtype = [("name", "S8"), ("score", "<i4")]
>>> dict_to_numpy_array(d, dtype)
array([(b'aaron', 5), (b'jack', 6)],
      dtype=[('name', 'S8'), ('score', '<i4')])

Parameters:

Name | Type | Description | Default
---|---|---|---
d | dict | the dictionary to convert | required
dtype | List[Tuple] | dtype of the resulting array | required

Returns:

Type | Description
---|---
np.array | the converted array
omicron/extensions/np.py
def dict_to_numpy_array(d: dict, dtype: List[Tuple]) -> np.array:
+ """convert dictionary to numpy array
+
+ Examples:
+
+ >>> d = {"aaron": 5, "jack": 6}
+ >>> dtype = [("name", "S8"), ("score", "<i4")]
+ >>> dict_to_numpy_array(d, dtype)
+ array([(b'aaron', 5), (b'jack', 6)],
+ dtype=[('name', 'S8'), ('score', '<i4')])
+
+ Args:
+ d (dict): [description]
+ dtype (List[Tuple]): [description]
+
+ Returns:
+ np.array: [description]
+ """
+ return np.fromiter(d.items(), dtype=dtype, count=len(d))
+
fill_nan(ts)

Replace the NaNs in `ts` with the preceding value.

If `ts` starts with NaN, those entries are replaced by the first non-NaN element.

If all elements are NaN, no replacement is possible and a ValueError is raised.

Examples:

>>> arr = np.arange(6, dtype=np.float32)
>>> arr[3:5] = np.NaN
>>> fill_nan(arr)
array([0., 1., 2., 2., 2., 5.], dtype=float32)

>>> arr = np.arange(6, dtype=np.float32)
>>> arr[0:2] = np.nan
>>> fill_nan(arr)
array([2., 2., 2., 3., 4., 5.], dtype=float32)

Parameters:

Name | Type | Description | Default
---|---|---|---
ts | np.array | the time series to fill | required
omicron/extensions/np.py
def fill_nan(ts: np.ndarray):
+ """将ts中的NaN替换为其前值
+
+ 如果ts起头的元素为NaN,则用第一个非NaN元素替换。
+
+ 如果所有元素都为NaN,则无法替换。
+
+ Example:
+ >>> arr = np.arange(6, dtype=np.float32)
+ >>> arr[3:5] = np.NaN
+ >>> fill_nan(arr)
+ ... # doctest: +NORMALIZE_WHITESPACE
+ array([0., 1., 2., 2., 2., 5.], dtype=float32)
+
+ >>> arr = np.arange(6, dtype=np.float32)
+ >>> arr[0:2] = np.nan
+ >>> fill_nan(arr)
+ ... # doctest: +NORMALIZE_WHITESPACE
+ array([2., 2., 2., 3., 4., 5.], dtype=float32)
+
+ Args:
+ ts (np.array): [description]
+ """
+ if np.all(np.isnan(ts)):
+ raise ValueError("all of ts are NaN")
+
+ if ts[0] is None or math.isnan(ts[0]):
+ idx = np.argwhere(~np.isnan(ts))[0]
+ ts[0] = ts[idx]
+
+ mask = np.isnan(ts)
+ idx = np.where(~mask, np.arange(mask.size), 0)
+ np.maximum.accumulate(idx, out=idx)
+ return ts[idx]
+
find_runs(x)

Find runs of consecutive items in an array.

Parameters:

Name | Type | Description | Default
---|---|---|---
x | ArrayLike | the sequence to find runs in | required

Returns:

Type | Description
---|---
Tuple[np.ndarray, np.ndarray, np.ndarray] | a tuple of unique values, start indices, and run lengths
omicron/extensions/np.py
def find_runs(x: ArrayLike) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Find runs of consecutive items in an array.
+
+ Args:
+ x: the sequence to find runs in
+
+ Returns:
+ A tuple of unique values, start indices, and length of runs
+ """
+
+ # ensure array
+ x = np.asanyarray(x)
+ if x.ndim != 1:
+ raise ValueError("only 1D array supported")
+ n = x.shape[0]
+
+ # handle empty array
+ if n == 0:
+ return np.array([]), np.array([]), np.array([])
+
+ else:
+ # find run starts
+ loc_run_start = np.empty(n, dtype=bool)
+ loc_run_start[0] = True
+ np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
+ run_starts = np.nonzero(loc_run_start)[0]
+
+ # find run values
+ run_values = x[loc_run_start]
+
+ # find run lengths
+ run_lengths = np.diff(np.append(run_starts, n))
+
+ return run_values, run_starts, run_lengths
+
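A minimal sketch of the return value:

```
>>> from omicron.extensions.np import find_runs
>>> find_runs([1, 1, 2, 2, 2, 3])  # (values, start indices, run lengths)
(array([1, 2, 3]), array([0, 2, 5]), array([2, 3, 1]))
```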
floor(arr, item)

In the array `arr`, find the value less than or equal to `item`. If `item` is smaller than every element of `arr`, return `arr[0]`; if it is larger than every element, return `arr[-1]`.

Unlike `minute_frames_floor`, this function does no wrap-around or carrying.

Examples:

>>> a = [3, 6, 9]
>>> floor(a, -1)
3
>>> floor(a, 9)
9
>>> floor(a, 10)
9
>>> floor(a, 4)
3

Parameters:

Name | Type | Description | Default
---|---|---|---
arr | | sorted array | required
item | | the value to floor | required
omicron/extensions/np.py
def floor(arr, item):
+ """
+ 在数据arr中,找到小于等于item的那一个值。如果item小于所有arr元素的值,返回arr[0];如果item
+ 大于所有arr元素的值,返回arr[-1]
+
+ 与`minute_frames_floor`不同的是,本函数不做回绕与进位.
+
+ Examples:
+ >>> a = [3, 6, 9]
+ >>> floor(a, -1)
+ 3
+ >>> floor(a, 9)
+ 9
+ >>> floor(a, 10)
+ 9
+ >>> floor(a, 4)
+ 3
+ >>> floor(a,10)
+ 9
+
+ Args:
+ arr:
+ item:
+
+ Returns:
+
+ """
+ if item < arr[0]:
+ return arr[0]
+ index = np.searchsorted(arr, item, side="right")
+ return arr[index - 1]
+
join_by_left(key, r1, r2, mask=True)

Left-join `r1` and `r2` by `key`.

If `r1` contains rows with no match in `r2`, the `r2` fields of those rows are masked, or filled with random values.

Same as numpy.lib.recfunctions.join_by(key, r1, r2, jointype='leftouter'), but allows `r1` to have duplicate keys.

Examples:

>>> # to join the following
>>> # [[ 1, 2],
>>> # [ 1, 3], x [[1, 5],
>>> # [ 2, 3]] [4, 7]]
>>> # only first two rows in left will be joined

>>> r1 = np.array([(1, 2), (1,3), (2,3)], dtype=[('seq', 'i4'), ('score', 'i4')])
>>> r2 = np.array([(1, 5), (4,7)], dtype=[('seq', 'i4'), ('age', 'i4')])
>>> joined = join_by_left('seq', r1, r2)
>>> print(joined)
[(1, 2, 5) (1, 3, 5) (2, 3, --)]

>>> print(joined.dtype)
(numpy.record, [('seq', '<i4'), ('score', '<i4'), ('age', '<i4')])

>>> joined[2][2]
masked

>>> joined.tolist()[2][2] == None
True

Parameters:

Name | Type | Description | Default
---|---|---|---
key | | the join key | required
r1 | | dataset 1 | required
r2 | | dataset 2 | required

Returns:

Type | Description
---|---
| a numpy array
omicron/extensions/np.py
def join_by_left(key, r1, r2, mask=True):
+ """左连接 `r1`, `r2` by `key`
+
+ 如果`r1`中存在`r2`中没有的行,则该行对应的`r2`中的那些字段将被mask,或者填充随机数。
+ same as numpy.lib.recfunctions.join_by(key, r1, r2, jointype='leftouter'), but allows r1 have duplicate keys
+
+ [Reference: stackoverflow](https://stackoverflow.com/a/53261882/13395693)
+
+ Examples:
+ >>> # to join the following
+ >>> # [[ 1, 2],
+ >>> # [ 1, 3], x [[1, 5],
+ >>> # [ 2, 3]] [4, 7]]
+ >>> # only first two rows in left will be joined
+
+ >>> r1 = np.array([(1, 2), (1,3), (2,3)], dtype=[('seq', 'i4'), ('score', 'i4')])
+ >>> r2 = np.array([(1, 5), (4,7)], dtype=[('seq', 'i4'), ('age', 'i4')])
+ >>> joined = join_by_left('seq', r1, r2)
+ >>> print(joined)
+ [(1, 2, 5) (1, 3, 5) (2, 3, --)]
+
+ >>> print(joined.dtype)
+ (numpy.record, [('seq', '<i4'), ('score', '<i4'), ('age', '<i4')])
+
+ >>> joined[2][2]
+ masked
+
+ >>> joined.tolist()[2][2] == None
+ True
+
+ Args:
+ key : join关键字
+ r1 : 数据集1
+ r2 : 数据集2
+
+ Returns:
+ a numpy array
+ """
+ # figure out the dtype of the result array
+ descr1 = r1.dtype.descr
+ descr2 = [d for d in r2.dtype.descr if d[0] not in r1.dtype.names]
+ descrm = descr1 + descr2
+
+ # figure out the fields we'll need from each array
+ f1 = [d[0] for d in descr1]
+ f2 = [d[0] for d in descr2]
+
+ # cache the number of columns in f1
+ ncol1 = len(f1)
+
+ # get a dict of the rows of r2 grouped by key
+ rows2 = {}
+ for row2 in r2:
+ rows2.setdefault(row2[key], []).append(row2)
+
+ # figure out how many rows will be in the result
+ nrowm = 0
+ for k1 in r1[key]:
+ if k1 in rows2:
+ nrowm += len(rows2[k1])
+ else:
+ nrowm += 1
+
+ # allocate the return array
+ # ret = np.full((nrowm, ), fill, dtype=descrm)
+ _ret = np.recarray(nrowm, dtype=descrm)
+ if mask:
+ ret = np.ma.array(_ret, mask=True)
+ else:
+ ret = _ret
+
+ # merge the data into the return array
+ i = 0
+ for row1 in r1:
+ if row1[key] in rows2:
+ for row2 in rows2[row1[key]]:
+ ret[i] = tuple(row1[f1]) + tuple(row2[f2])
+ i += 1
+ else:
+ for j in range(ncol1):
+ ret[i][j] = row1[j]
+ i += 1
+
+ return ret
+
numpy_append_fields(base, names, data, dtypes)

Add new fields to an existing array `base`.

This reimplements `numpy.lib.recfunctions.rec_append_fields`; it is provided because `rec_append_fields` cannot handle `data` whose elements are of type Object.

The new columns are appended to the right of the existing columns.

Examples:

>>> # append a single field
>>> import numpy
>>> old = np.array([i for i in range(3)], dtype=[('col1', '<f4')])
>>> new_list = [2 * i for i in range(3)]
>>> res = numpy_append_fields(old, 'new_col', new_list, [('new_col', '<f4')])
>>> print(res)
[(0., 0.) (1., 2.) (2., 4.)]

>>> # append multiple fields
>>> data = [res['col1'].tolist(), res['new_col'].tolist()]
>>> print(numpy_append_fields(old, ('col3', 'col4'), data, [('col3', '<f4'), ('col4', '<f4')]))
[(0., 0., 0.) (1., 1., 2.) (2., 2., 4.)]

Parameters:

Name | Type | Description | Default
---|---|---|---
base | numpy.array | the base array | required
names | str or List[str] | name(s) of the new field(s): a string for a single field, or a list of strings | required
data | list | data for the new field(s) | required
dtypes | List | dtype(s) of the new field(s) | required
omicron/extensions/np.py
def numpy_append_fields(
+ base: np.ndarray, names: Union[str, List[str]], data: List, dtypes: List
+) -> np.ndarray:
+ """给现有的数组`base`增加新的字段
+
+ 实现了`numpy.lib.recfunctions.rec_append_fields`的功能。提供这个功能,是因为`rec_append_fields`不能处理`data`元素的类型为Object的情况。
+
+ 新增的数据列将顺序排列在其它列的右边。
+
+ Example:
+ >>> # 新增单个字段
+ >>> import numpy
+ >>> old = np.array([i for i in range(3)], dtype=[('col1', '<f4')])
+ >>> new_list = [2 * i for i in range(3)]
+ >>> res = numpy_append_fields(old, 'new_col', new_list, [('new_col', '<f4')])
+ >>> print(res)
+ ... # doctest: +NORMALIZE_WHITESPACE
+ [(0., 0.) (1., 2.) (2., 4.)]
+
+ >>> # 新增多个字段
+ >>> data = [res['col1'].tolist(), res['new_col'].tolist()]
+ >>> print(numpy_append_fields(old, ('col3', 'col4'), data, [('col3', '<f4'), ('col4', '<f4')]))
+ ... # doctest: +NORMALIZE_WHITESPACE
+ [(0., 0., 0.) (1., 1., 2.) (2., 2., 4.)]
+
+ Args:
+ base ([numpy.array]): 基础数组
+ names ([type]): 新增字段的名字,可以是字符串(单字段的情况),也可以是字符串列表
+ data (list): 增加的字段的数据,list类型
+ dtypes ([type]): 新增字段的dtype
+ """
+ if isinstance(names, str):
+ names = [names]
+ data = [data]
+
+ result = np.empty(base.shape, dtype=base.dtype.descr + dtypes)
+ for col in base.dtype.names:
+ result[col] = base[col]
+
+ for i in range(len(names)):
+ result[names[i]] = data[i]
+
+ return result
+
remove_nan(ts)

Remove the NaNs from `ts`.

Parameters:

Name | Type | Description | Default
---|---|---|---
ts | np.array | the input array | required

Returns:

Type | Description
---|---
np.array | the array with NaNs removed
omicron/extensions/np.py
def remove_nan(ts: np.ndarray) -> np.ndarray:
+ """从`ts`中去除NaN
+
+ Args:
+ ts (np.array): [description]
+
+ Returns:
+ np.array: [description]
+ """
+ return ts[~np.isnan(ts.astype(float))]
+
replace_zero(ts, replacement=None)

Replace the zeros in `ts` with the preceding value; commonly used when processing volume data.

If `replacement` is given, zeros are replaced with `replacement` instead.

omicron/extensions/np.py
def replace_zero(ts: np.ndarray, replacement=None) -> np.ndarray:
+ """将ts中的0替换为前值, 处理volume数据时常用用到
+
+ 如果提供了replacement, 则替换为replacement
+
+ """
+ if replacement is not None:
+ return np.where(ts == 0, replacement, ts)
+
+ if np.all(ts == 0):
+ raise ValueError("all of ts are 0")
+
+ if ts[0] == 0:
+ idx = np.argwhere(ts != 0)[0]
+ ts[0] = ts[idx]
+
+ mask = ts == 0
+ idx = np.where(~mask, np.arange(mask.size), 0)
+ np.maximum.accumulate(idx, out=idx)
+ return ts[idx]
+
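A minimal sketch of both modes:

```
>>> import numpy as np
>>> from omicron.extensions.np import replace_zero
>>> volume = np.array([3., 0., 0., 5., 0.])
>>> replace_zero(volume)        # zeros forward-filled with the previous value
array([3., 3., 3., 5., 5.])
>>> replace_zero(volume, 1.0)   # or replaced with a fixed value
array([3., 1., 1., 5., 1.])
```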
rolling(x, win, func)

Apply `func` over a sliding window on the sequence `x`.

If the computation you need is argmax, argmin, max, mean, median, min, rank, std, sum, or var, use bottleneck's move_argmax, move_argmin, move_max, move_mean, move_median, move_min, move_rank, move_std, move_sum, or move_var instead; those functions perform better.

Parameters:

Name | Type | Description | Default
---|---|---|---
x | | the input sequence | required
win | | window size | required
func | | function applied to each window | required

Returns:

Type | Description
---|---
np.array | one result per window
omicron/extensions/np.py
def rolling(x, win, func):
+ """对序列`x`进行窗口滑动计算。
+
+ 如果`func`要实现的功能是argmax, argmin, max, mean, median, min, rank, std, sum, var等,move_argmax,请使用bottleneck中的move_argmin, move_max, move_mean, move_median, move_min move_rank, move_std, move_sum, move_var。这些函数的性能更好。
+
+ Args:
+ x ([type]): [description]
+ win ([type]): [description]
+ func ([type]): [description]
+
+ Returns:
+ [type]: [description]
+ """
+ results = []
+ for subarray in sliding_window_view(x, window_shape=win):
+ results.append(func(subarray))
+
+ return np.array(results)
+
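A minimal sketch:

```
>>> import numpy as np
>>> from omicron.extensions.np import rolling
>>> rolling(np.arange(5), 3, np.mean)   # windows [0,1,2], [1,2,3], [2,3,4]
array([1., 2., 3.])
```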
shift(arr, start, offset)

In the numpy array `arr`, locate `start` (or the element closest to it) and return the element `offset` positions away.

`arr` must be sorted. A positive `offset` shifts forward; a negative `offset` shifts backward.

Examples:

>>> arr = [20050104, 20050105, 20050106, 20050107, 20050110, 20050111]
>>> shift(arr, 20050104, 1)
20050105

>>> shift(arr, 20050105, -1)
20050104

>>> # the start point is already out of bounds on the right; shifting right returns the start point
>>> shift(arr, 20050120, 1)
20050120

Parameters:

Name | Type | Description | Default
---|---|---|---
arr | | sorted array | required
start | | any dtype acceptable to numpy | required
offset | int | number of positions to shift | required

Returns:

Type | Description
---|---
| the element found after shifting
omicron/extensions/np.py
def shift(arr, start, offset):
+ """在numpy数组arr中,找到start(或者最接近的一个),取offset对应的元素。
+
+ 要求`arr`已排序。`offset`为正,表明向后移位;`offset`为负,表明向前移位
+
+ Examples:
+ >>> arr = [20050104, 20050105, 20050106, 20050107, 20050110, 20050111]
+ >>> shift(arr, 20050104, 1)
+ 20050105
+
+ >>> shift(arr, 20050105, -1)
+ 20050104
+
+ >>> # 起始点已右越界,且向右shift,返回起始点
+ >>> shift(arr, 20050120, 1)
+ 20050120
+
+
+ Args:
+ arr : 已排序的数组
+ start : numpy可接受的数据类型
+ offset (int): [description]
+
+ Returns:
+ 移位后得到的元素值
+ """
+ pos = np.searchsorted(arr, start, side="right")
+
+ if pos + offset - 1 >= len(arr):
+ return start
+ else:
+ return arr[pos + offset - 1]
+
smallest_n_argpos(ts, n)

Get the smallest n elements and return their positions, ordered by value ascending (min -> max).

Examples:

>>> smallest_n_argpos([np.nan, 4, 3, 9, 8, 5, 2, 1, 0, 6, 7], 2)
array([8, 7])

Parameters:

Name | Type | Description | Default
---|---|---|---
ts | np.array | the input array | required
n | int | how many smallest elements to take | required

Returns:

Type | Description
---|---
np.array | positions of the smallest n elements
omicron/extensions/np.py
def smallest_n_argpos(ts: np.array, n: int) -> np.array:
+ """get smallest n (min->max) elements and return argpos which its value ordered in ascent
+
+ Example:
+ >>> smallest_n_argpos([np.nan, 4, 3, 9, 8, 5, 2, 1, 0, 6, 7], 2)
+ array([8, 7])
+
+ Args:
+ ts (np.array): 输入的数组
+ n (int): 取最小的n个元素
+
+ Returns:
+ np.array: [description]
+ """
+ return np.argsort(ts)[:n]
+
to_pydatetime(tm)

Convert a numpy.datetime64 object into a Python datetime object.

The numpy.ndarray.item() method can convert any numpy object to a Python object; prefer .item() over this function wherever applicable. Example:

```
arr = np.array(['2022-09-08', '2022-09-09'], dtype='datetime64[s]')
arr.item(0)   # output is datetime.datetime(2022, 9, 8, 0, 0)
arr[1].item() # output is datetime.datetime(2022, 9, 9, 0, 0)
```

Parameters:

Name | Type | Description | Default
---|---|---|---
tm | | the input numpy datetime object | required

Returns:

Type | Description
---|---
datetime.datetime | python datetime object

.. deprecated:: 2.0.0 use `tm.item()` instead
omicron/extensions/np.py
@deprecated("2.0.0", details="use `tm.item()` instead")
+def to_pydatetime(tm: np.datetime64) -> datetime.datetime:
+ """将numpy.datetime64对象转换成为python的datetime对象
+
+ numpy.ndarray.item()方法可用以将任何numpy对象转换成python对象,推荐在任何适用的地方使用.item()方法,而不是本方法。示例:
+ ```
+ arr = np.array(['2022-09-08', '2022-09-09'], dtype='datetime64[s]')
+ arr.item(0) # output is datetime.datetime(2022, 9, 8, 0, 0)
+
+ arr[1].item() # output is datetime.datetime(2022, 9, 9, 0, 0)
+ ```
+
+ Args:
+ tm : the input numpy datetime object
+
+ Returns:
+ python datetime object
+ """
+ unix_epoch = np.datetime64(0, "s")
+ one_second = np.timedelta64(1, "s")
+ seconds_since_epoch = (tm - unix_epoch) / one_second
+
+ return datetime.datetime.utcfromtimestamp(seconds_since_epoch)
+
top_n_argpos(ts, n)

Get the top n elements and return their positions, ordered by value descending (max -> min).

Examples:

>>> top_n_argpos([np.nan, 4, 3, 9, 8, 5, 2, 1, 0, 6, 7], 2)
array([3, 4])

Parameters:

Name | Type | Description | Default
---|---|---|---
ts | np.array | the input array | required
n | int | how many largest elements to take | required

Returns:

Type | Description
---|---
np.array | positions of the top n elements
omicron/extensions/np.py
def top_n_argpos(ts: np.array, n: int) -> np.array:
+ """get top n (max->min) elements and return argpos which its value ordered in descent
+
+ Example:
+ >>> top_n_argpos([np.nan, 4, 3, 9, 8, 5, 2, 1, 0, 6, 7], 2)
+ array([3, 4])
+
+ Args:
+ ts (np.array): [description]
+ n (int): [description]
+
+ Returns:
+ np.array: [description]
+ """
+ ts_ = np.copy(ts)
+ ts_[np.isnan(ts_)] = -np.inf
+ return np.argsort(ts_)[-n:][::-1]
+
dingtalk

DingTalkMessage

A message-pushing class for DingTalk robots, wrapping the common message types and the signing algorithm.

The robot's access_token must be set in the configuration file. If signing is enabled, the robot's secret must also be configured. If custom keywords are enabled, configure the robot's keyword, with multiple keywords separated by commas.

A full configuration example follows; secret and keyword are optional, access_token is mandatory:

```
notify:
    dingtalk_access_token: xxxx
    dingtalk_secret: xxxx
```

omicron/notify/dingtalk.py
class DingTalkMessage:
+ """
+ 钉钉的机器人消息推送类,封装了常用的消息类型以及加密算法
+ 需要在配置文件中配置钉钉的机器人的access_token
+ 如果配置了加签,需要在配置文件中配置钉钉的机器人的secret
+ 如果配置了自定义关键词,需要在配置文件中配置钉钉的机器人的keyword,多个关键词用英文逗号分隔
+ 全部的配置文件示例如下, 其中secret和keyword可以不配置, access_token必须配置
+ notify:
+ dingtalk_access_token: xxxx
+ dingtalk_secret: xxxx
+ """
+
+ url = "https://oapi.dingtalk.com/robot/send"
+
+ @classmethod
+ def _get_access_token(cls):
+ """获取钉钉机器人的access_token"""
+ if hasattr(cfg.notify, "dingtalk_access_token"):
+ return cfg.notify.dingtalk_access_token
+ else:
+ logger.error(
+ "Dingtalk not configured, please add the following items:\n"
+ "notify:\n"
+ " dingtalk_access_token: xxxx\n"
+ " dingtalk_secret: xxxx\n"
+ )
+ raise ConfigError("dingtalk_access_token not found")
+
+ @classmethod
+ def _get_secret(cls):
+ """获取钉钉机器人的secret"""
+ if hasattr(cfg.notify, "dingtalk_secret"):
+ return cfg.notify.dingtalk_secret
+ else:
+ return None
+
+ @classmethod
+ def _get_url(cls):
+ """获取钉钉机器人的消息推送地址,将签名和时间戳拼接在url后面"""
+ access_token = cls._get_access_token()
+ url = f"{cls.url}?access_token={access_token}"
+ secret = cls._get_secret()
+ if secret:
+ timestamp, sign = cls._get_sign(secret)
+ url = f"{url}×tamp={timestamp}&sign={sign}"
+ return url
+
+ @classmethod
+ def _get_sign(cls, secret: str):
+ """获取签名发送给钉钉机器人"""
+ timestamp = str(round(time.time() * 1000))
+ secret_enc = secret.encode("utf-8")
+ string_to_sign = "{}\n{}".format(timestamp, secret)
+ string_to_sign_enc = string_to_sign.encode("utf-8")
+ hmac_code = hmac.new(
+ secret_enc, string_to_sign_enc, digestmod=hashlib.sha256
+ ).digest()
+ sign = urllib.parse.quote_plus(base64.b64encode(hmac_code))
+ return timestamp, sign
+
+ @classmethod
+ def _send(cls, msg):
+ """发送消息到钉钉机器人"""
+ url = cls._get_url()
+ response = httpx.post(url, json=msg, timeout=30)
+ if response.status_code != 200:
+ logger.error(
+ f"failed to send message, content: {msg}, response from Dingtalk: {response.content.decode()}"
+ )
+ return
+ rsp = json.loads(response.content)
+ if rsp.get("errcode") != 0:
+ logger.error(
+ f"failed to send message, content: {msg}, response from Dingtalk: {rsp}"
+ )
+ return response.content.decode()
+
+ @classmethod
+ async def _send_async(cls, msg):
+ """发送消息到钉钉机器人"""
+ url = cls._get_url()
+ async with httpx.AsyncClient() as client:
+ r = await client.post(url, json=msg, timeout=30)
+ if r.status_code != 200:
+ logger.error(
+ f"failed to send message, content: {msg}, response from Dingtalk: {r.content.decode()}"
+ )
+ return
+ rsp = json.loads(r.content)
+ if rsp.get("errcode") != 0:
+ logger.error(
+ f"failed to send message, content: {msg}, response from Dingtalk: {rsp}"
+ )
+ return r.content.decode()
+
+ @classmethod
+ @deprecated("2.0.0", details="use function `ding` instead")
+ def text(cls, content):
+ msg = {"text": {"content": content}, "msgtype": "text"}
+ return cls._send(msg)
+
text(cls, content) (classmethod)

.. deprecated:: 2.0.0 use function `ding` instead

omicron/notify/dingtalk.py
@classmethod
+@deprecated("2.0.0", details="use function `ding` instead")
+def text(cls, content):
+ msg = {"text": {"content": content}, "msgtype": "text"}
+ return cls._send(msg)
+
ding(msg)

Send a message to a DingTalk robot.

Supports plain-text and markdown messages. To send a markdown message, pass a dict containing both a "title" and a "text" field. For details, see the DingTalk Open Platform documentation.

Important: this function must be called from an async thread (i.e. the thread running the asyncio loop), otherwise an exception is raised. It returns an Awaitable: you can await its completion, or ignore the return value, in which case it runs as a background task with no guaranteed completion time.

Parameters:

Name | Type | Description | Default
---|---|---|---
msg | Union[str, dict] | the message to send | required

Returns:

Type | Description
---|---
Awaitable | the background send task; you can use this handle to cancel the task
omicron/notify/dingtalk.py
def ding(msg: Union[str, dict]) -> Awaitable:
+ """发送消息到钉钉机器人
+
+ 支持发送纯文本消息和markdown格式的文本消息。如果要发送markdown格式的消息,请通过字典传入,必须包含包含"title"和"text"两个字段。更详细信息,请见[钉钉开放平台文档](https://open.dingtalk.com/document/orgapp-server/message-type)
+
+ ???+ Important
+ 必须在异步线程(即运行asyncio loop的线程)中调用此方法,否则会抛出异常。
+ 此方法返回一个Awaitable,您可以等待它完成,也可以忽略返回值,此时它将作为一个后台任务执行,但完成的时间不确定。
+
+ Args:
+ msg: 待发送消息。
+
+ Returns:
+ 发送消息的后台任务。您可以使用此返回句柄来取消任务。
+ """
+ if isinstance(msg, str):
+ msg_ = {"text": {"content": msg}, "msgtype": "text"}
+ elif isinstance(msg, dict):
+ msg_ = {
+ "msgtype": "markdown",
+ "markdown": {"title": msg["title"], "text": msg["text"]},
+ }
+ else:
+ raise TypeError
+
+ task = asyncio.create_task(DingTalkMessage._send_async(msg_))
+ return task
+
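A usage sketch, assuming `notify.dingtalk_access_token` (and optionally `dingtalk_secret`) are configured:

```
import asyncio

from omicron.notify.dingtalk import ding

async def main():
    # plain text; returns a task you may await or ignore
    task = ding("omicron: daily job finished")
    # markdown: the dict must contain "title" and "text"
    await ding({"title": "Report", "text": "**all jobs succeeded**"})
    await task

asyncio.run(main())
```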
mail

compose(subject, plain_txt=None, html=None, attachment=None)

Compose a MIME email.

Parameters:

Name | Type | Description | Default
---|---|---|---
subject | str | the mail subject | required
plain_txt | str | plain-text mail body | None
html | str | HTML mail body | None
attachment | str | attachment file name | None

Returns:

Type | Description
---|---
EmailMessage | the MIME mail
omicron/notify/mail.py
def compose(
+ subject: str, plain_txt: str = None, html: str = None, attachment: str = None
+) -> EmailMessage:
+ """编写MIME邮件。
+
+ Args:
+ subject (str): 邮件主题
+ plain_txt (str): 纯文本格式的邮件内容
+ html (str, optional): html格式的邮件内容. Defaults to None.
+ attachment (str, optional): 附件文件名
+ Returns:
+ MIME mail
+ """
+ msg = EmailMessage()
+
+ msg["Subject"] = subject
+
+ if html:
+ msg.preamble = plain_txt or ""
+ msg.set_content(html, subtype="html")
+ else:
+ assert plain_txt, "Either plain_txt or html is required."
+ msg.set_content(plain_txt)
+
+ if attachment:
+ ctype, encoding = mimetypes.guess_type(attachment)
+ if ctype is None or encoding is not None:
+ ctype = "application/octet-stream"
+
+ maintype, subtype = ctype.split("/", 1)
+ with open(attachment, "rb") as f:
+ msg.add_attachment(
+ f.read(), maintype=maintype, subtype=subtype, filename=attachment
+ )
+
+ return msg
+
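A sketch of composing a mail with an HTML body and an attachment (`report.csv` is a hypothetical local file):

```
from omicron.notify.mail import compose

msg = compose(
    "Daily report",
    plain_txt="all jobs finished",
    html="<h1>All jobs finished</h1>",
    attachment="report.csv",
)
print(msg["Subject"])  # Daily report
```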
mail_notify(subject=None, body=None, msg=None, html=False, receivers=None)

Send an email notification.

Configure the sender, receivers, and mail server through cfg4py:

```
notify:
    mail_from: aaron_yang@jieyu.ai
    mail_to:
        - code@jieyu.ai
    mail_server: smtp.ym.163.com
```

Configure the authentication password through the environment variable `MAIL_PASSWORD`.

Exactly one of subject/body or msg must be provided.

Important: this function must be called from an async thread (i.e. the thread running the asyncio loop), otherwise an exception is raised. It returns an Awaitable: you can await its completion, or ignore the return value, in which case it runs as a background task with no guaranteed completion time.

Parameters:

Name | Type | Description | Default
---|---|---|---
msg | EmailMessage | a pre-composed message | None
subject | str | the mail subject | None
body | str | the mail body | None
html | bool | whether to treat body as HTML | False
receivers | List[str], optional | receiver addresses; if not provided, the preconfigured receivers are used | None

Returns:

Type | Description
---|---
Awaitable | the background send task; you can use this handle to cancel the task
omicron/notify/mail.py
def mail_notify(
+ subject: str = None,
+ body: str = None,
+ msg: EmailMessage = None,
+ html=False,
+ receivers=None,
+) -> Awaitable:
+ """发送邮件通知。
+
+ 发送者、接收者及邮件服务器等配置请通过cfg4py配置:
+
+ ```
+ notify:
+ mail_from: aaron_yang@jieyu.ai
+ mail_to:
+ - code@jieyu.ai
+ mail_server: smtp.ym.163.com
+ ```
+ 验证密码请通过环境变量`MAIL_PASSWORD`来配置。
+
+ subject/body与msg必须提供其一。
+
+ ???+ Important
+ 必须在异步线程(即运行asyncio loop的线程)中调用此方法,否则会抛出异常。
+ 此方法返回一个Awaitable,您可以等待它完成,也可以忽略返回值,此时它将作为一个后台任务执行,但完成的时间不确定。
+
+ Args:
+ msg (EmailMessage, optional): [description]. Defaults to None.
+ subject (str, optional): [description]. Defaults to None.
+ body (str, optional): [description]. Defaults to None.
+ html (bool, optional): body是否按html格式处理? Defaults to False.
+ receivers (List[str], Optional): 接收者信息。如果不提供,将使用预先配置的接收者信息。
+
+ Returns:
+ 发送消息的后台任务。您可以使用此返回句柄来取消任务。
+ """
+ if all([msg is not None, subject or body]):
+ raise TypeError("msg参数与subject/body只能提供其中之一")
+ elif all([msg is None, subject is None, body is None]):
+ raise TypeError("必须提供msg参数或者subjecdt/body参数")
+
+ if msg is None:
+ if html:
+ msg = compose(subject, html=body)
+ else:
+ msg = compose(subject, plain_txt=body)
+
+ cfg = cfg4py.get_instance()
+ if not receivers:
+ receivers = cfg.notify.mail_to
+
+ password = os.environ.get("MAIL_PASSWORD")
+ return send_mail(
+ cfg.notify.mail_from, receivers, password, msg, host=cfg.notify.mail_server
+ )
+
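A usage sketch, assuming `notify.mail_from`, `notify.mail_to` and `notify.mail_server` are configured and `MAIL_PASSWORD` is set in the environment:

```
import asyncio

from omicron.notify.mail import mail_notify

async def main():
    # provide either subject/body or a pre-composed msg, not both
    await mail_notify(subject="backtest done", body="sharpe: 1.2")

asyncio.run(main())
```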
send_mail(sender, receivers, password, msg=None, host=None, port=25, cc=None, bcc=None, subject=None, body=None, username=None)

Send an email notification.

For a simple text mail, call send_mail(sender, receivers, subject=subject, body=body). For a more complex mail with HTML and attachments, first call compose() to build an EmailMessage, then call send_mail(sender, receivers, msg) to send it.

Important: this function must be called from an async thread (i.e. the thread running the asyncio loop), otherwise an exception is raised. It returns an Awaitable: you can await its completion, or ignore the return value, in which case it runs as a background task with no guaranteed completion time.

Parameters:

Name | Type | Description | Default
---|---|---|---
sender | str | sender address | required
receivers | List[str] | receiver addresses | required
msg | EmailMessage | a pre-composed message | None
host | str | mail server host | None
port | int | mail server port | 25
cc | List[str] | carbon-copy addresses | None
bcc | List[str] | blind-carbon-copy addresses | None
subject | str | the mail subject | None
body | str | the mail body | None
username | str | the username used to log on to the mail server; if not provided, `sender` is used | None

Returns:

Type | Description
---|---
Awaitable | the background send task; you can use this handle to cancel the task
omicron/notify/mail.py
@retry(aiosmtplib.errors.SMTPConnectError, tries=3, backoff=2, delay=30, logger=logger)
+def send_mail(
+ sender: str,
+ receivers: List[str],
+ password: str,
+ msg: EmailMessage = None,
+ host: str = None,
+ port: int = 25,
+ cc: List[str] = None,
+ bcc: List[str] = None,
+ subject: str = None,
+ body: str = None,
+ username: str = None,
+) -> Awaitable:
+ """发送邮件通知。
+
+ 如果只发送简单的文本邮件,请使用 send_mail(sender, receivers, subject=subject, plain=plain)。如果要发送较复杂的带html和附件的邮件,请先调用compose()生成一个EmailMessage,然后再调用send_mail(sender, receivers, msg)来发送邮件。
+
+ ???+ Important
+ 必须在异步线程(即运行asyncio loop的线程)中调用此方法,否则会抛出异常。
+ 此方法返回一个Awaitable,您可以等待它完成,也可以忽略返回值,此时它将作为一个后台任务执行,但完成的时间不确定。
+
+ Args:
+ sender (str): [description]
+ receivers (List[str]): [description]
+ msg (EmailMessage, optional): [description]. Defaults to None.
+ host (str, optional): [description]. Defaults to None.
+ port (int, optional): [description]. Defaults to 25.
+ cc (List[str], optional): [description]. Defaults to None.
+ bcc (List[str], optional): [description]. Defaults to None.
+ subject (str, optional): [description]. Defaults to None.
+        body (str, optional): [description]. Defaults to None.
+ username (str, optional): the username used to logon to mail server. if not provided, then `sender` is used.
+
+ Returns:
+ 发送消息的后台任务。您可以使用此返回句柄来取消任务。
+ """
+ if all([msg is not None, subject is not None or body is not None]):
+ raise TypeError("msg参数与subject/body只能提供其中之一")
+ elif all([msg is None, subject is None, body is None]):
+ raise TypeError("必须提供msg参数或者subjecdt/body参数")
+
+ msg = msg or EmailMessage()
+
+ if isinstance(receivers, str):
+ receivers = [receivers]
+
+ msg["From"] = sender
+ msg["To"] = ", ".join(receivers)
+
+ if subject:
+ msg["subject"] = subject
+
+ if body:
+ msg.set_content(body)
+
+ if cc:
+ msg["Cc"] = ", ".join(cc)
+ if bcc:
+ msg["Bcc"] = ", ".join(bcc)
+
+ username = username or sender
+
+ if host is None:
+ host = sender.split("@")[-1]
+
+ task = asyncio.create_task(
+ aiosmtplib.send(
+            msg, hostname=host, port=port, username=username, password=password
+ )
+ )
+
+ return task
+
Info

Since 2.0.0.a76

During backtests, log timestamps should generally reflect the backtest's own time rather than the system time. This module makes the logging time rewritable.

Usage: obtain the logger through `BacktestLogger.getLogger(__name__)` instead of the usual `logging.getLogger(__name__)`, then pass `date=...` when emitting a log record. If `date` is not passed, the call time is used and the behavior matches the standard logging system. See the sketch below.

Warning

When calling logger.exception, the date parameter cannot be passed.

Configuration through a config file is also supported.
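A minimal sketch of the intended usage; the import path of BacktestLogger is an assumption, check your omicron version:

```
import datetime
import logging

from omicron.core.backtestlog import BacktestLogger  # import path is an assumption

logging.basicConfig(level=logging.INFO)

# BacktestLogger.getLogger instead of logging.getLogger
logger = BacktestLogger.getLogger(__name__)

# pass the backtest time via `date=`; omit it to fall back to the call time
logger.info("buy 000001.XSHE", date=datetime.date(2023, 3, 1))
```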
Plot candlestick charts.

Note: the examples must be run in a notebook, otherwise no figure is produced.

By default, two indicator subplots are shown: volume and RSI. Both can be customized; see the sketch below and the Candlestick API that follows.
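A minimal sketch (run in a notebook); `bars` is assumed to be a numpy structured array with frame/open/high/low/close/volume fields, e.g. as returned by omicron's bar-fetching APIs:

```
from omicron.plotting.candlestick import Candlestick

cs = Candlestick(bars, title="000001.XSHE", show_volume=True, show_rsi=True)
cs.mark_support_resist_lines()  # optional annotation
cs.plot()
```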
Candlestick

omicron/plotting/candlestick.py
class Candlestick:
+ RED = "#FF4136"
+ GREEN = "#3DAA70"
+ TRANSPARENT = "rgba(0,0,0,0)"
+ LIGHT_GRAY = "rgba(0, 0, 0, 0.1)"
+ MA_COLORS = {
+ 5: "#1432F5",
+ 10: "#EB52F7",
+ 20: "#C0C0C0",
+ 30: "#882111",
+ 60: "#5E8E28",
+ 120: "#4294F7",
+ 250: "#F09937",
+ }
+
+ def __init__(
+ self,
+ bars: np.ndarray,
+ ma_groups: List[int] = None,
+ title: str = None,
+ show_volume=True,
+ show_rsi=True,
+ show_peaks=False,
+ width=None,
+ height=None,
+ **kwargs,
+ ):
+ """构造函数
+
+ Args:
+ bars: 行情数据
+ ma_groups: 均线组参数。比如[5, 10, 20]表明向k线图中添加5, 10, 20日均线。如果不提供,将从数组[5, 10, 20, 30, 60, 120, 250]中取直到与`len(bars) - 5`匹配的参数为止。比如bars长度为30,则将取[5, 10, 20]来绘制均线。
+ title: k线图的标题
+ show_volume: 是否显示成交量图
+ show_rsi: 是否显示RSI图。缺省显示参数为6的RSI图。
+ show_peaks: 是否标记检测出来的峰跟谷。
+ width: the width in 'px' units of the figure
+ height: the height in 'px' units of the figure
+ Keyword Args:
+ rsi_win int: default is 6
+ """
+ self.title = title
+ self.bars = bars
+ self.width = width
+ self.height = height
+
+ # traces for main area
+ self.main_traces = {}
+
+ # traces for indicator area
+ self.ind_traces = {}
+
+ self.ticks = self._format_tick(bars["frame"])
+ self._bar_close = array_math_round(bars["close"], 2).astype(np.float64)
+
+ # for every candlestick, it must contain a candlestick plot
+ cs = go.Candlestick(
+ x=self.ticks,
+ open=bars["open"],
+ high=bars["high"],
+ low=bars["low"],
+ close=self._bar_close,
+ line=dict({"width": 1}),
+ name="K线",
+ **kwargs,
+ )
+
+ # Set line and fill colors
+ cs.increasing.fillcolor = "rgba(255,255,255,0.9)"
+ cs.increasing.line.color = self.RED
+ cs.decreasing.fillcolor = self.GREEN
+ cs.decreasing.line.color = self.GREEN
+
+ self.main_traces["ohlc"] = cs
+
+ if show_volume:
+ self.add_indicator("volume")
+
+ if show_peaks:
+ self.add_main_trace("peaks")
+
+ if show_rsi:
+ self.add_indicator("rsi", win=kwargs.get("rsi_win", 6))
+
+ # 增加均线
+ if ma_groups is None:
+ nbars = len(bars)
+ if nbars < 9:
+ ma_groups = []
+ else:
+ groups = np.array([5, 10, 20, 30, 60, 120, 250])
+ idx = max(np.argwhere(groups < (nbars - 5))).item() + 1
+ ma_groups = groups[:idx]
+
+ for win in ma_groups:
+ name = f"ma{win}"
+ if win > len(bars):
+ continue
+ ma = moving_average(self._bar_close, win)
+ line = go.Scatter(
+ y=ma,
+ x=self.ticks,
+ name=name,
+ line=dict(width=1, color=self.MA_COLORS.get(win)),
+ )
+ self.main_traces[name] = line
+
+ @property
+ def figure(self):
+ """返回一个figure对象"""
+ rows = len(self.ind_traces) + 1
+ specs = [[{"secondary_y": False}]] * rows
+ specs[0][0]["secondary_y"] = True
+
+ row_heights = [0.7, *([0.3 / (rows - 1)] * (rows - 1))]
+ cols = 1
+
+ fig = make_subplots(
+ rows=rows,
+ cols=cols,
+ shared_xaxes=True,
+ vertical_spacing=0.1,
+ subplot_titles=(self.title, *self.ind_traces.keys()),
+ row_heights=row_heights,
+ specs=specs,
+ )
+
+ for _, trace in self.main_traces.items():
+ fig.add_trace(trace, row=1, col=1)
+
+ for i, (_, trace) in enumerate(self.ind_traces.items()):
+ fig.add_trace(trace, row=i + 2, col=1)
+
+ ymin = np.min(self.bars["low"])
+ ymax = np.max(self.bars["high"])
+
+ ylim = [ymin * 0.95, ymax * 1.05]
+
+ # 显示十字光标
+ fig.update_xaxes(
+ showgrid=False,
+ showspikes=True,
+ spikemode="across",
+ spikesnap="cursor",
+ spikecolor="grey",
+ spikedash="solid",
+ spikethickness=1,
+ )
+
+ fig.update_yaxes(
+ showspikes=True,
+ spikemode="across",
+ spikesnap="cursor",
+ spikedash="solid",
+ spikecolor="grey",
+ spikethickness=1,
+ showgrid=True,
+ gridcolor=self.LIGHT_GRAY,
+ )
+
+ fig.update_xaxes(
+ nticks=len(self.bars) // 10,
+ ticklen=10,
+ ticks="outside",
+ minor=dict(nticks=5, ticklen=5, ticks="outside"),
+ row=rows,
+ col=1,
+ )
+
+ # 设置K线显示区域
+ if self.width:
+ win_size = int(self.width // 10)
+ else:
+ win_size = 120
+
+ fig.update_xaxes(
+ type="category", range=[len(self.bars) - win_size, len(self.bars) - 1]
+ )
+
+ fig.update_layout(
+ yaxis=dict(range=ylim),
+ hovermode="x unified",
+ plot_bgcolor=self.TRANSPARENT,
+ xaxis_rangeslider_visible=False,
+ )
+
+ if self.width:
+ fig.update_layout(width=self.width)
+
+ if self.height:
+ fig.update_layout(height=self.height)
+
+ return fig
+
+ def _format_tick(self, tm: np.array) -> NDArray:
+ if tm.item(0).hour == 0: # assume it's date
+ return np.array(
+ [
+ f"{x.item().year:02}-{x.item().month:02}-{x.item().day:02}"
+ for x in tm
+ ]
+ )
+ else:
+ return np.array(
+ [
+ f"{x.item().month:02}-{x.item().day:02} {x.item().hour:02}:{x.item().minute:02}"
+ for x in tm
+ ]
+ )
+
+ def _remove_ma(self):
+ traces = {}
+ for name in self.main_traces:
+ if not name.startswith("ma"):
+ traces[name] = self.main_traces[name]
+
+ self.main_traces = traces
+
+ def add_main_trace(self, trace_name: str, **kwargs):
+ """add trace to main plot
+
+ 支持的图例类别有peaks, bbox(bounding-box), bt(回测), support_line, resist_line
+ Args:
+ trace_name : 图例名称
+ **kwargs : 其他参数
+
+ """
+ if trace_name == "peaks":
+ self.mark_peaks_and_valleys(
+ kwargs.get("up_thres", 0.03), kwargs.get("down_thres", -0.03)
+ )
+
+ # 标注矩形框
+ elif trace_name == "bbox":
+ self.add_bounding_box(kwargs.get("boxes"))
+
+ # 回测结果
+ elif trace_name == "bt":
+ self.add_backtest_result(kwargs.get("bt"))
+
+ # 增加直线
+ elif trace_name == "support_line":
+ self.add_line("支撑线", kwargs.get("x"), kwargs.get("y"))
+
+ elif trace_name == "resist_line":
+ self.add_line("压力线", kwargs.get("x"), kwargs.get("y"))
+
+ def add_line(self, trace_name: str, x: List[int], y: List[float]):
+ """在k线图上增加以`x`,`y`表示的一条直线
+
+ Args:
+ trace_name : 图例名称
+ x : x轴坐标,所有的x值都必须属于[0, len(self.bars)]
+ y : y值
+ """
+ line = go.Scatter(x=self.ticks[x], y=y, mode="lines", name=trace_name)
+
+ self.main_traces[trace_name] = line
+
+ def mark_support_resist_lines(
+ self, upthres: float = None, downthres: float = None, use_close=True, win=60
+ ):
+ """在K线图上标注支撑线和压力线
+
+ 在`win`个k线内,找出所有的局部峰谷点,并以最高的两个峰连线生成压力线,以最低的两个谷连线生成支撑线。
+
+ Args:
+ upthres : 用来检测峰谷时使用的阈值,参见`omicron.talib.morph.peaks_and_valleys`
+ downthres : 用来检测峰谷时使用的阈值,参见`omicron.talib.morph.peaks_and_valleys`.
+ use_close : 是否使用收盘价来进行检测。如果为False,则使用high来检测压力线,使用low来检测支撑线.
+ win : 检测局部高低点的窗口.
+ """
+ bars = self.bars[-win:]
+ clipped = len(self.bars) - win
+
+ if use_close:
+ support, resist, x_start = support_resist_lines(
+ self._bar_close, upthres, downthres
+ )
+ x = np.arange(len(bars))[x_start:]
+
+ self.add_main_trace("support_line", x=x + clipped, y=support(x))
+ self.add_main_trace("resist_line", x=x + clipped, y=resist(x))
+
+ else: # 使用"high"和"low"
+ bars = self.bars[-win:]
+ support, _, x_start = support_resist_lines(bars["low"], upthres, downthres)
+ x = np.arange(len(bars))[x_start:]
+ self.add_main_trace("support_line", x=x + clipped, y=support(x))
+
+ _, resist, x_start = support_resist_lines(bars["high"], upthres, downthres)
+ x = np.arange(len(bars))[x_start:]
+ self.add_main_trace("resist_line", x=x + clipped, y=resist(x))
+
+ def mark_bbox(self, min_size: int = 20):
+ """在k线图上检测并标注矩形框
+
+ Args:
+ min_size : 矩形框的最小长度
+
+ """
+ boxes = plateaus(self._bar_close, min_size)
+ self.add_main_trace("bbox", boxes=boxes)
+
+ def mark_backtest_result(self, result: dict):
+ """标记买卖点和回测数据
+
+ TODO:
+ 此方法可能未与backtest返回值同步。此外,在portofolio回测中,不可能在k线图中使用此方法。
+
+ Args:
+ points : 买卖点的坐标。
+ """
+ trades = result.get("trades")
+ assets = result.get("assets")
+
+ x, y, labels = [], [], []
+ hover = []
+ labels_color = defaultdict(list)
+
+ for trade in trades:
+ trade_date = arrow.get(trade["time"]).date()
+ asset = assets.get(trade_date)
+
+ security = trade["security"]
+ price = trade["price"]
+ volume = trade["volume"]
+
+ side = trade["order_side"]
+
+ x.append(self._format_tick(trade_date))
+
+ bar = self.bars[self.bars["frame"] == trade_date]
+ if side == "买入":
+ hover.append(
+ f"总资产:{asset}<br><br>{side}:{security}<br>买入价:{price}<br>股数:{volume}"
+ )
+
+ y.append(bar["high"][0] * 1.1)
+ labels.append("B")
+ labels_color["color"].append(self.RED)
+
+ else:
+ y.append(bar["low"][0] * 0.99)
+
+ hover.append(
+ f"总资产:{asset}<hr><br>{side}:{security}<br>卖出价:{price}<br>股数:{volume}"
+ )
+
+ labels.append("S")
+ labels_color["color"].append(self.GREEN)
+
+ # txt.append(f'{side}:{security}<br>卖出价:{price}<br>股数:{volume}')
+
+ trace = go.Scatter(
+ x=x,
+ y=y,
+ mode="text",
+ text=labels,
+ name="backtest",
+ hovertext=hover,
+ textfont=labels_color,
+ )
+
+ self.main_traces["bs"] = trace
+
+ def mark_peaks_and_valleys(
+ self, up_thres: Optional[float] = None, down_thres: Optional[float] = None
+ ):
+ """在K线图上标注峰谷点
+
+ Args:
+ up_thres : 用来检测峰谷时使用的阈值,参见[omicron.talib.morph.peaks_and_valleys][]
+ down_thres : 用来检测峰谷时使用的阈值,参见[omicron.talib.morph.peaks_and_valleys][]
+
+ """
+ bars = self.bars
+
+ flags = peaks_and_valleys(self._bar_close, up_thres, down_thres)
+
+ # 移除首尾的顶底标记,一般情况下它们都不是真正的顶和底。
+ flags[0] = 0
+ flags[-1] = 0
+
+ marker_margin = (max(bars["high"]) - min(bars["low"])) * 0.05
+ ticks_up = self.ticks[flags == 1]
+ y_up = bars["high"][flags == 1] + marker_margin
+ ticks_down = self.ticks[flags == -1]
+ y_down = bars["low"][flags == -1] - marker_margin
+
+ trace = go.Scatter(
+ mode="markers", x=ticks_up, y=y_up, marker_symbol="triangle-down", name="峰"
+ )
+ self.main_traces["peaks"] = trace
+
+ trace = go.Scatter(
+ mode="markers",
+ x=ticks_down,
+ y=y_down,
+ marker_symbol="triangle-up",
+ name="谷",
+ )
+ self.main_traces["valleys"] = trace
+
+ def add_bounding_box(self, boxes: List[Tuple]):
+ """bbox是标记在k线图上某个区间内的矩形框,它以该区间最高价和最低价为上下边。
+
+ Args:
+ boxes: 每个元素(start, width)表示各个bbox的起点和宽度。
+ """
+ for j, box in enumerate(boxes):
+ x, y = [], []
+ i, width = box
+ if len(x):
+ x.append(None)
+ y.append(None)
+
+ group = self.bars[i : i + width]
+
+ mean = np.mean(group["close"])
+ std = 2 * np.std(group["close"])
+
+ # 落在两个标准差以内的实体最上方和最下方值
+ hc = np.max(group[group["close"] < mean + std]["close"])
+ lc = np.min(group[group["close"] > mean - std]["close"])
+
+ ho = np.max(group[group["open"] < mean + std]["open"])
+ lo = np.min(group[group["open"] > mean - std]["open"])
+
+ h = max(hc, ho)
+ low = min(lo, lc)
+
+ x.extend(self.ticks[[i, i + width - 1, i + width - 1, i, i]])
+ y.extend((h, h, low, low, h))
+
+ hover = f"宽度: {width}<br>振幅: {h/low - 1:.2%}"
+ trace = go.Scatter(x=x, y=y, fill="toself", name=f"平台整理{j}", text=hover)
+ self.main_traces[f"bbox-{j}"] = trace
+
+ def add_indicator(self, indicator: str, **kwargs):
+ """向k线图中增加技术指标
+
+ Args:
+ indicator: 当前支持值有'volume', 'rsi', 'bbands'
+ kwargs: 计算某个indicator时,需要的参数。比如计算bbands时,需要传入均线的window
+ """
+ if indicator == "volume":
+ colors = np.repeat(self.RED, len(self.bars))
+ colors[self.bars["close"] <= self.bars["open"]] = self.GREEN
+
+ trace = go.Bar(
+ x=self.ticks,
+ y=self.bars["volume"],
+ showlegend=False,
+ marker={"color": colors},
+ )
+ elif indicator == "rsi":
+ win = kwargs.get("win")
+ rsi = talib.RSI(self._bar_close, win) # type: ignore
+ trace = go.Scatter(x=self.ticks, y=rsi, showlegend=False)
+ elif indicator == "bbands":
+ self._remove_ma()
+ win = kwargs.get("win")
+ for name, ind in zip(
+ ["bbands-high", "bbands-mean", "bbands-low"],
+ talib.BBANDS(self._bar_close, win), # type: ignore
+ ):
+ trace = go.Scatter(x=self.ticks, y=ind, showlegend=True, name=name)
+ self.main_traces[name] = trace
+
+ return
+ else:
+ raise ValueError(f"{indicator} not supported")
+
+ self.ind_traces[indicator] = trace
+
+ def add_marks(
+ self,
+ x: List[int],
+ y: List[float],
+ name: str,
+ marker: str = "cross",
+ color: Optional[str] = None,
+ ):
+ """向k线图中增加标记点"""
+ trace = go.Scatter(
+ x=self.ticks[x],
+ y=y,
+ mode="markers",
+ marker_symbol=marker,
+ marker_color=color,
+ name=name,
+ )
+ self.main_traces[name] = trace
+
+ def plot(self):
+ """绘制图表"""
+ fig = self.figure
+ fig.show()
+
figure (property, readonly)

Return a figure object.

__init__(self, bars, ma_groups=None, title=None, show_volume=True, show_rsi=True, show_peaks=False, width=None, height=None, **kwargs) (special)

Constructor.
Parameters:

Name | Type | Description | Default
---|---|---|---
bars | ndarray | bar (OHLC) data | required
ma_groups | List[int] | moving-average windows, e.g. [5, 10, 20] adds 5-, 10- and 20-period moving averages to the chart; if not provided, windows are taken from [5, 10, 20, 30, 60, 120, 250] up to `len(bars) - 5` (e.g. for 30 bars, [5, 10, 20] is used) | None
title | str | chart title | None
show_volume | | whether to show the volume subplot | True
show_rsi | | whether to show the RSI subplot (window 6 by default) | True
show_peaks | | whether to mark detected peaks and valleys | False
width | | the width in 'px' units of the figure | None
height | | the height in 'px' units of the figure | None

Keyword arguments:

Name | Type | Description
---|---|---
rsi_win | int | default is 6
omicron/plotting/candlestick.py
def __init__(
+ self,
+ bars: np.ndarray,
+ ma_groups: List[int] = None,
+ title: str = None,
+ show_volume=True,
+ show_rsi=True,
+ show_peaks=False,
+ width=None,
+ height=None,
+ **kwargs,
+):
+ """构造函数
+
+ Args:
+ bars: 行情数据
+ ma_groups: 均线组参数。比如[5, 10, 20]表明向k线图中添加5, 10, 20日均线。如果不提供,将从数组[5, 10, 20, 30, 60, 120, 250]中取直到与`len(bars) - 5`匹配的参数为止。比如bars长度为30,则将取[5, 10, 20]来绘制均线。
+ title: k线图的标题
+ show_volume: 是否显示成交量图
+ show_rsi: 是否显示RSI图。缺省显示参数为6的RSI图。
+ show_peaks: 是否标记检测出来的峰跟谷。
+ width: the width in 'px' units of the figure
+ height: the height in 'px' units of the figure
+ Keyword Args:
+ rsi_win int: default is 6
+ """
+ self.title = title
+ self.bars = bars
+ self.width = width
+ self.height = height
+
+ # traces for main area
+ self.main_traces = {}
+
+ # traces for indicator area
+ self.ind_traces = {}
+
+ self.ticks = self._format_tick(bars["frame"])
+ self._bar_close = array_math_round(bars["close"], 2).astype(np.float64)
+
+ # for every candlestick, it must contain a candlestick plot
+ cs = go.Candlestick(
+ x=self.ticks,
+ open=bars["open"],
+ high=bars["high"],
+ low=bars["low"],
+ close=self._bar_close,
+ line=dict({"width": 1}),
+ name="K线",
+ **kwargs,
+ )
+
+ # Set line and fill colors
+ cs.increasing.fillcolor = "rgba(255,255,255,0.9)"
+ cs.increasing.line.color = self.RED
+ cs.decreasing.fillcolor = self.GREEN
+ cs.decreasing.line.color = self.GREEN
+
+ self.main_traces["ohlc"] = cs
+
+ if show_volume:
+ self.add_indicator("volume")
+
+ if show_peaks:
+ self.add_main_trace("peaks")
+
+ if show_rsi:
+ self.add_indicator("rsi", win=kwargs.get("rsi_win", 6))
+
+ # 增加均线
+ if ma_groups is None:
+ nbars = len(bars)
+ if nbars < 9:
+ ma_groups = []
+ else:
+ groups = np.array([5, 10, 20, 30, 60, 120, 250])
+ idx = max(np.argwhere(groups < (nbars - 5))).item() + 1
+ ma_groups = groups[:idx]
+
+ for win in ma_groups:
+ name = f"ma{win}"
+ if win > len(bars):
+ continue
+ ma = moving_average(self._bar_close, win)
+ line = go.Scatter(
+ y=ma,
+ x=self.ticks,
+ name=name,
+ line=dict(width=1, color=self.MA_COLORS.get(win)),
+ )
+ self.main_traces[name] = line
+
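+下面是一个最小的构造与绘图示意(假设omicron已初始化、类名为`Candlestick`;`FrameType`的导入路径为假设):
+
+```python
+import arrow
+from coretypes import FrameType  # 假设:FrameType由coretypes提供
+from omicron.models.stock import Stock
+from omicron.plotting.candlestick import Candlestick
+
+async def show_candlestick(code: str = "000001.XSHE"):
+    # 取最近120根日线;ma_groups不传时,将按len(bars) - 5自动选择均线组
+    bars = await Stock.get_bars(code, 120, FrameType.DAY, arrow.now().date())
+    cs = Candlestick(bars, title=code, show_volume=True, show_rsi=True)
+    cs.plot()
+```
+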
add_bounding_box(self, boxes)
+
+
+¶bbox是标记在k线图上某个区间内的矩形框,它以该区间最高价和最低价为上下边。
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| boxes | List[Tuple] | 每个元素(start, width)表示各个bbox的起点和宽度。 | required |
omicron/plotting/candlestick.py
def add_bounding_box(self, boxes: List[Tuple]):
+ """bbox是标记在k线图上某个区间内的矩形框,它以该区间最高价和最低价为上下边。
+
+ Args:
+ boxes: 每个元素(start, width)表示各个bbox的起点和宽度。
+ """
+ for j, box in enumerate(boxes):
+ x, y = [], []
+ i, width = box
+ if len(x):
+ x.append(None)
+ y.append(None)
+
+ group = self.bars[i : i + width]
+
+ mean = np.mean(group["close"])
+ std = 2 * np.std(group["close"])
+
+ # 落在两个标准差以内的实体最上方和最下方值
+ hc = np.max(group[group["close"] < mean + std]["close"])
+ lc = np.min(group[group["close"] > mean - std]["close"])
+
+ ho = np.max(group[group["open"] < mean + std]["open"])
+ lo = np.min(group[group["open"] > mean - std]["open"])
+
+ h = max(hc, ho)
+ low = min(lo, lc)
+
+ x.extend(self.ticks[[i, i + width - 1, i + width - 1, i, i]])
+ y.extend((h, h, low, low, h))
+
+ hover = f"宽度: {width}<br>振幅: {h/low - 1:.2%}"
+ trace = go.Scatter(x=x, y=y, fill="toself", name=f"平台整理{j}", text=hover)
+ self.main_traces[f"bbox-{j}"] = trace
+
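+一个调用示意(假设`cs`为已构造的Candlestick对象,区间参数仅为演示):
+
+```python
+# 在从第10根bar开始、宽度为20的区间上画一个平台整理矩形框
+cs.add_bounding_box([(10, 20)])
+cs.plot()
+```
+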
add_indicator(self, indicator, **kwargs)
+
+
+¶向k线图中增加技术指标
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| indicator | str | 当前支持值有'volume', 'rsi', 'bbands' | required |
+| kwargs | | 计算某个indicator时,需要的参数。比如计算bbands时,需要传入均线的window | {} |
+
omicron/plotting/candlestick.py
def add_indicator(self, indicator: str, **kwargs):
+ """向k线图中增加技术指标
+
+ Args:
+ indicator: 当前支持值有'volume', 'rsi', 'bbands'
+ kwargs: 计算某个indicator时,需要的参数。比如计算bbands时,需要传入均线的window
+ """
+ if indicator == "volume":
+ colors = np.repeat(self.RED, len(self.bars))
+ colors[self.bars["close"] <= self.bars["open"]] = self.GREEN
+
+ trace = go.Bar(
+ x=self.ticks,
+ y=self.bars["volume"],
+ showlegend=False,
+ marker={"color": colors},
+ )
+ elif indicator == "rsi":
+ win = kwargs.get("win")
+ rsi = talib.RSI(self._bar_close, win) # type: ignore
+ trace = go.Scatter(x=self.ticks, y=rsi, showlegend=False)
+ elif indicator == "bbands":
+ self._remove_ma()
+ win = kwargs.get("win")
+ for name, ind in zip(
+ ["bbands-high", "bbands-mean", "bbands-low"],
+ talib.BBANDS(self._bar_close, win), # type: ignore
+ ):
+ trace = go.Scatter(x=self.ticks, y=ind, showlegend=True, name=name)
+ self.main_traces[name] = trace
+
+ return
+ else:
+ raise ValueError(f"{indicator} not supported")
+
+ self.ind_traces[indicator] = trace
+
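+一个调用示意(假设`cs`为已构造的Candlestick对象):
+
+```python
+cs.add_indicator("volume")          # 成交量副图
+cs.add_indicator("rsi", win=6)      # RSI副图,窗口为6
+cs.add_indicator("bbands", win=20)  # 布林带叠加到主图,并移除已有均线
+```
+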
add_line(self, trace_name, x, y)
+
+
+¶在k线图上增加以`x`,`y`表示的一条直线
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| trace_name | | 图例名称 | required |
+| x | | x轴坐标,所有的x值都必须属于[0, len(self.bars)] | required |
+| y | | y值 | required |
omicron/plotting/candlestick.py
def add_line(self, trace_name: str, x: List[int], y: List[float]):
+ """在k线图上增加以`x`,`y`表示的一条直线
+
+ Args:
+ trace_name : 图例名称
+ x : x轴坐标,所有的x值都必须属于[0, len(self.bars)]
+ y : y值
+ """
+ line = go.Scatter(x=self.ticks[x], y=y, mode="lines", name=trace_name)
+
+ self.main_traces[trace_name] = line
+
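+一个调用示意(假设`cs`为已构造的Candlestick对象):
+
+```python
+n = len(cs.bars)
+# 连接首尾两根bar收盘价的直线(仅为演示)
+cs.add_line("趋势线", [0, n - 1], [cs.bars["close"][0], cs.bars["close"][-1]])
+```
+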
add_main_trace(self, trace_name, **kwargs)
+
+
+¶add trace to main plot
+支持的图例类别有peaks, bbox(bounding-box), bt(回测), support_line, resist_line
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| trace_name | | 图例名称 | required |
+| **kwargs | | 其他参数 | {} |
+
omicron/plotting/candlestick.py
def add_main_trace(self, trace_name: str, **kwargs):
+ """add trace to main plot
+
+ 支持的图例类别有peaks, bbox(bounding-box), bt(回测), support_line, resist_line
+ Args:
+ trace_name : 图例名称
+ **kwargs : 其他参数
+
+ """
+ if trace_name == "peaks":
+ self.mark_peaks_and_valleys(
+ kwargs.get("up_thres", 0.03), kwargs.get("down_thres", -0.03)
+ )
+
+ # 标注矩形框
+ elif trace_name == "bbox":
+ self.add_bounding_box(kwargs.get("boxes"))
+
+ # 回测结果
+ elif trace_name == "bt":
+ self.add_backtest_result(kwargs.get("bt"))
+
+ # 增加直线
+ elif trace_name == "support_line":
+ self.add_line("支撑线", kwargs.get("x"), kwargs.get("y"))
+
+ elif trace_name == "resist_line":
+ self.add_line("压力线", kwargs.get("x"), kwargs.get("y"))
+
add_marks(self, x, y, name, marker='cross', color=None)
+
+
+¶向k线图中增加标记点
+ +omicron/plotting/candlestick.py
def add_marks(
+ self,
+ x: List[int],
+ y: List[float],
+ name: str,
+ marker: str = "cross",
+ color: Optional[str] = None,
+):
+ """向k线图中增加标记点"""
+ trace = go.Scatter(
+ x=self.ticks[x],
+ y=y,
+ mode="markers",
+ marker_symbol=marker,
+ marker_color=color,
+ name=name,
+ )
+ self.main_traces[name] = trace
+
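+一个调用示意(假设`cs`为已构造的Candlestick对象,下标仅为演示):
+
+```python
+idx = [5, 20]
+ys = (cs.bars["low"][idx] * 0.99).tolist()  # 标记点放在最低价下方1%处
+cs.add_marks(idx, ys, name="信号", marker="cross", color="blue")
+```
+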
mark_backtest_result(self, result)
+
+
+¶标记买卖点和回测数据
+Todo
+
+此方法可能未与backtest返回值同步。此外,在portfolio回测中,不可能在k线图中使用此方法。
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| result | dict | 回测结果,包含trades与assets等数据。 | required |
omicron/plotting/candlestick.py
def mark_backtest_result(self, result: dict):
+ """标记买卖点和回测数据
+
+ TODO:
+ 此方法可能未与backtest返回值同步。此外,在portfolio回测中,不可能在k线图中使用此方法。
+
+ Args:
+ result : 回测结果,包含trades与assets等数据。
+ """
+ trades = result.get("trades")
+ assets = result.get("assets")
+
+ x, y, labels = [], [], []
+ hover = []
+ labels_color = defaultdict(list)
+
+ for trade in trades:
+ trade_date = arrow.get(trade["time"]).date()
+ asset = assets.get(trade_date)
+
+ security = trade["security"]
+ price = trade["price"]
+ volume = trade["volume"]
+
+ side = trade["order_side"]
+
+ x.append(self._format_tick(trade_date))
+
+ bar = self.bars[self.bars["frame"] == trade_date]
+ if side == "买入":
+ hover.append(
+ f"总资产:{asset}<br><br>{side}:{security}<br>买入价:{price}<br>股数:{volume}"
+ )
+
+ y.append(bar["high"][0] * 1.1)
+ labels.append("B")
+ labels_color["color"].append(self.RED)
+
+ else:
+ y.append(bar["low"][0] * 0.99)
+
+ hover.append(
+ f"总资产:{asset}<hr><br>{side}:{security}<br>卖出价:{price}<br>股数:{volume}"
+ )
+
+ labels.append("S")
+ labels_color["color"].append(self.GREEN)
+
+
+ trace = go.Scatter(
+ x=x,
+ y=y,
+ mode="text",
+ text=labels,
+ name="backtest",
+ hovertext=hover,
+ textfont=labels_color,
+ )
+
+ self.main_traces["bs"] = trace
+
mark_bbox(self, min_size=20)
+
+
+¶在k线图上检测并标注矩形框
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| min_size | | 矩形框的最小长度 | 20 |
+
omicron/plotting/candlestick.py
def mark_bbox(self, min_size: int = 20):
+ """在k线图上检测并标注矩形框
+
+ Args:
+ min_size : 矩形框的最小长度
+
+ """
+ boxes = plateaus(self._bar_close, min_size)
+ self.add_main_trace("bbox", boxes=boxes)
+
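+一个调用示意(假设`cs`为已构造的Candlestick对象):
+
+```python
+cs.mark_bbox(min_size=20)  # 自动检测长度不小于20的平台整理区间并画框
+cs.plot()
+```
+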
mark_peaks_and_valleys(self, up_thres=None, down_thres=None)
+
+
+¶在K线图上标注峰谷点
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| up_thres | | 用来检测峰谷时使用的阈值,参见omicron.talib.morph.peaks_and_valleys | None |
+| down_thres | | 用来检测峰谷时使用的阈值,参见omicron.talib.morph.peaks_and_valleys | None |
+
omicron/plotting/candlestick.py
def mark_peaks_and_valleys(
+ self, up_thres: Optional[float] = None, down_thres: Optional[float] = None
+):
+ """在K线图上标注峰谷点
+
+ Args:
+ up_thres : 用来检测峰谷时使用的阈值,参见[omicron.talib.morph.peaks_and_valleys][]
+ down_thres : 用来检测峰谷时使用的阈值,参见[omicron.talib.morph.peaks_and_valleys][]
+
+ """
+ bars = self.bars
+
+ flags = peaks_and_valleys(self._bar_close, up_thres, down_thres)
+
+ # 移除首尾的顶底标记,一般情况下它们都不是真正的顶和底。
+ flags[0] = 0
+ flags[-1] = 0
+
+ marker_margin = (max(bars["high"]) - min(bars["low"])) * 0.05
+ ticks_up = self.ticks[flags == 1]
+ y_up = bars["high"][flags == 1] + marker_margin
+ ticks_down = self.ticks[flags == -1]
+ y_down = bars["low"][flags == -1] - marker_margin
+
+ trace = go.Scatter(
+ mode="markers", x=ticks_up, y=y_up, marker_symbol="triangle-down", name="峰"
+ )
+ self.main_traces["peaks"] = trace
+
+ trace = go.Scatter(
+ mode="markers",
+ x=ticks_down,
+ y=y_down,
+ marker_symbol="triangle-up",
+ name="谷",
+ )
+ self.main_traces["valleys"] = trace
+
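+一个调用示意(假设`cs`为已构造的Candlestick对象):
+
+```python
+cs.mark_peaks_and_valleys(up_thres=0.03, down_thres=-0.03)  # 以±3%阈值标注峰谷
+```
+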
mark_support_resist_lines(self, upthres=None, downthres=None, use_close=True, win=60)
+
+
+¶在K线图上标注支撑线和压力线
+在`win`个k线内,找出所有的局部峰谷点,并以最高的两个峰连线生成压力线,以最低的两个谷连线生成支撑线。
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| upthres | | 用来检测峰谷时使用的阈值,参见`omicron.talib.morph.peaks_and_valleys` | None |
+| downthres | | 用来检测峰谷时使用的阈值,参见`omicron.talib.morph.peaks_and_valleys` | None |
+| use_close | | 是否使用收盘价来进行检测。如果为False,则使用high来检测压力线,使用low来检测支撑线. | True |
+| win | | 检测局部高低点的窗口. | 60 |
+
omicron/plotting/candlestick.py
def mark_support_resist_lines(
+ self, upthres: float = None, downthres: float = None, use_close=True, win=60
+):
+ """在K线图上标注支撑线和压力线
+
+ 在`win`个k线内,找出所有的局部峰谷点,并以最高的两个峰连线生成压力线,以最低的两个谷连线生成支撑线。
+
+ Args:
+ upthres : 用来检测峰谷时使用的阈值,参见`omicron.talib.morph.peaks_and_valleys`
+ downthres : 用来检测峰谷时使用的阈值,参见`omicron.talib.morph.peaks_and_valleys`.
+ use_close : 是否使用收盘价来进行检测。如果为False,则使用high来检测压力线,使用low来检测支撑线.
+ win : 检测局部高低点的窗口.
+ """
+ bars = self.bars[-win:]
+ clipped = len(self.bars) - win
+
+ if use_close:
+ support, resist, x_start = support_resist_lines(
+ self._bar_close, upthres, downthres
+ )
+ x = np.arange(len(bars))[x_start:]
+
+ self.add_main_trace("support_line", x=x + clipped, y=support(x))
+ self.add_main_trace("resist_line", x=x + clipped, y=resist(x))
+
+ else: # 使用"high"和"low"
+ bars = self.bars[-win:]
+ support, _, x_start = support_resist_lines(bars["low"], upthres, downthres)
+ x = np.arange(len(bars))[x_start:]
+ self.add_main_trace("support_line", x=x + clipped, y=support(x))
+
+ _, resist, x_start = support_resist_lines(bars["high"], upthres, downthres)
+ x = np.arange(len(bars))[x_start:]
+ self.add_main_trace("resist_line", x=x + clipped, y=resist(x))
+
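+一个调用示意(假设`cs`为已构造的Candlestick对象):
+
+```python
+# 基于最近60根bar的收盘价生成支撑线与压力线
+cs.mark_support_resist_lines(use_close=True, win=60)
+cs.plot()
+```
+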
plot(self)
+
+
+¶绘制图表
+ +omicron/plotting/candlestick.py
def plot(self):
+ """绘制图表"""
+ fig = self.figure
+ fig.show()
+
绘制回测资产曲线和指标图。
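+原示例代码在页面转换中丢失;下面按本模块构造函数的约定给出一个使用示意(`bills`与`metrics`需由策略回测产生):
+
+```python
+from omicron.plotting.metrics import MetricsGraph
+
+async def plot_backtest(bills: dict, metrics: dict):
+    # bills/metrics分别通过Strategy.bills与Strategy.metrics获得
+    mg = MetricsGraph(bills, metrics, baseline_code="399300.XSHE")
+    await mg.plot()
+```
+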
+MetricsGraph
+
+
+
+¶omicron/plotting/metrics.py
class MetricsGraph:
+ def __init__(
+ self,
+ bills: dict,
+ metrics: dict,
+ baseline_code: str = "399300.XSHE",
+ indicator: Optional[pd.DataFrame] = None,
+ ):
+ """
+ Args:
+ bills: 回测生成的账单,通过Strategy.bills获得
+ metrics: 回测生成的指标,通过strategy.metrics获得
+ baseline_code: 基准证券代码
+ indicator: 回测时使用的指标。如果存在,将叠加到策略回测图上。它应该是一个以日期为索引,指标值列名为"value"的pandas.DataFrame。如果不提供,将不会绘制指标图
+ """
+ self.metrics = metrics
+ self.trades = bills["trades"]
+ self.positions = bills["positions"]
+ self.start = arrow.get(bills["assets"][0][0]).date()
+ self.end = arrow.get(bills["assets"][-1][0]).date()
+
+ self.frames = [
+ tf.int2date(f) for f in tf.get_frames(self.start, self.end, FrameType.DAY)
+ ]
+
+ if indicator is not None:
+ self.indicator = indicator.join(
+ pd.Series(index=self.frames, name="frames", dtype=np.float64),
+ how="right",
+ )
+ else:
+ self.indicator = None
+
+ # 记录日期到下标的反向映射
+ self._frame2pos = {f: i for i, f in enumerate(self.frames)}
+ self.ticks = self._format_tick(self.frames)
+
+ # TODO: there's bug in backtesting, temporarily fix here
+ df = pd.DataFrame(self.frames, columns=["frame"])
+ df["assets"] = np.nan
+ assets = pd.DataFrame(bills["assets"], columns=["frame", "assets"])
+ df["assets"] = assets["assets"]
+ self.assets = df.fillna(method="ffill")["assets"].to_numpy()
+ self.nv = self.assets / self.assets[0]
+
+ self.baseline_code = baseline_code or "399300.XSHE"
+
+ def _fill_missing_prices(self, bars: BarsArray, frames: Union[List, NDArray]):
+ """将bars中缺失值采用其前值替换
+
+ 当baseline为个股时,可能存在停牌的情况,这样导致由此计算的参考收益无法与回测的资产收益对齐,因此需要进行调整。
+
+ 出于这个目的,本函数只返回处理后的收盘价。
+
+ Args:
+ bars: 基线行情数据。
+ frames: 日期索引
+
+ Returns:
+ 补充缺失值后的收盘价序列
+ """
+ _close = pd.DataFrame(
+ {
+ "close": pd.Series(bars["close"], index=bars["frame"]),
+ "frame": pd.Series(np.empty((len(frames),)), index=frames),
+ }
+ )["close"].to_numpy()
+
+ # 这里使用omicron中的fill_nan,是因为如果数组的第一个元素即为NaN的话,那么DataFrame.fillna(method='ffill')将无法处理这样的情况(仍然保持为nan)
+
+ return fill_nan(_close)
+
+ def _format_tick(self, frames: Union[Frame, List[Frame]]) -> Union[str, NDArray]:
+ if type(frames) == datetime.date:
+ x = frames
+ return f"{x.year:02}-{x.month:02}-{x.day:02}"
+ elif type(frames) == datetime.datetime:
+ x = frames
+ return f"{x.month:02}-{x.day:02} {x.hour:02}:{x.minute:02}"
+ elif type(frames[0]) == datetime.date: # type: ignore
+ return np.array([f"{x.year:02}-{x.month:02}-{x.day:02}" for x in frames])
+ else:
+ return np.array(
+ [f"{x.month:02}-{x.day:02} {x.hour:02}:{x.minute:02}" for x in frames] # type: ignore
+ )
+
+ async def _metrics_trace(self):
+ metric_names = {
+ "start": "起始日",
+ "end": "结束日",
+ "window": "资产暴露窗口",
+ "total_tx": "交易次数",
+ "total_profit": "总利润",
+ "total_profit_rate": "利润率",
+ "win_rate": "胜率",
+ "mean_return": "日均回报",
+ "sharpe": "夏普率",
+ "max_drawdown": "最大回撤",
+ "annual_return": "年化回报",
+ "volatility": "波动率",
+ "sortino": "sortino",
+ "calmar": "calmar",
+ }
+
+ # bug: plotly go.Table.Cells format does not work here
+ metric_formatter = {
+ "start": "{}",
+ "end": "{}",
+ "window": "{}",
+ "total_tx": "{}",
+ "total_profit": "{:.2f}",
+ "total_profit_rate": "{:.2%}",
+ "win_rate": "{:.2%}",
+ "mean_return": "{:.2%}",
+ "sharpe": "{:.2f}",
+ "max_drawdown": "{:.2%}",
+ "annual_return": "{:.2%}",
+ "volatility": "{:.2%}",
+ "sortino": "{:.2f}",
+ "calmar": "{:.2f}",
+ }
+
+ metrics = deepcopy(self.metrics)
+ baseline = metrics["baseline"] or {}
+ del metrics["baseline"]
+
+ baseline_name = (
+ await Security.alias(self.baseline_code) if self.baseline_code else "基准"
+ )
+
+ metrics_formatted = []
+ for k in metric_names.keys():
+ if metrics.get(k):
+ metrics_formatted.append(metric_formatter[k].format(metrics.get(k)))
+ else:
+ metrics_formatted.append("-")
+
+ baseline_formatted = []
+ for k in metric_names.keys():
+ if baseline.get(k):
+ baseline_formatted.append(metric_formatter[k].format(baseline.get(k)))
+ else:
+ baseline_formatted.append("-")
+
+ return go.Table(
+ header=dict(values=["指标名", "策略", baseline_name]),
+ cells=dict(
+ values=[
+ [v for _, v in metric_names.items()],
+ metrics_formatted,
+ baseline_formatted,
+ ],
+ font_size=10,
+ ),
+ )
+
+ async def _trade_info_trace(self):
+ """构建hover text 序列"""
+ # convert trades into hover_info
+ buys = defaultdict(list)
+ sells = defaultdict(list)
+ for _, trade in self.trades.items():
+ trade_date = arrow.get(trade["time"]).date()
+
+ ipos = self._frame2pos.get(trade_date)
+ if ipos is None:
+ logger.warning(
+ "date %s in trade record not in backtest range", trade_date
+ )
+ continue
+
+ name = await Security.alias(trade["security"])
+ price = trade["price"]
+ side = trade["order_side"]
+ filled = trade["filled"]
+
+ trade_text = f"{side}:{name} {filled/100:.0f}手 价格:{price:.02f} 成交额:{filled * price/10000:.1f}万"
+
+ if side == "卖出":
+ sells[trade_date].append(trade_text)
+ elif side in ("买入", "分红配股"):
+ buys[trade_date].append(trade_text)
+
+ X_buy, Y_buy, data_buy = [], [], []
+ X_sell, Y_sell, data_sell = [], [], []
+
+ for dt, text in buys.items():
+ ipos = self._frame2pos.get(dt)
+ Y_buy.append(self.nv[ipos])
+ X_buy.append(self._format_tick(dt))
+
+ asset = self.assets[ipos]
+ hover = f"资产:{asset/10000:.1f}万<br>{'<br>'.join(text)}"
+ data_buy.append(hover)
+
+ trace_buy = go.Scatter(
+ x=X_buy,
+ y=Y_buy,
+ mode="markers",
+ text=data_buy,
+ name="买入成交",
+ marker=dict(color="red", symbol="triangle-up"),
+ hovertemplate="<br>%{text}",
+ )
+
+ for dt, text in sells.items():
+ ipos = self._frame2pos.get(dt)
+ Y_sell.append(self.nv[ipos])
+ X_sell.append(self._format_tick(dt))
+
+ asset = self.assets[ipos]
+ hover = f"资产:{asset/10000:.1f}万<br>{'<br>'.join(text)}"
+ data_sell.append(hover)
+
+ trace_sell = go.Scatter(
+ x=X_sell,
+ y=Y_sell,
+ mode="markers",
+ text=data_sell,
+ name="卖出成交",
+ marker=dict(color="green", symbol="triangle-down"),
+ hovertemplate="<br>%{text}",
+ )
+
+ return trace_buy, trace_sell
+
+ async def plot(self):
+ """绘制资产曲线及回测指标图"""
+ n = len(self.assets)
+ bars = await Stock.get_bars(self.baseline_code, n, FrameType.DAY, self.end)
+
+ baseline_prices = self._fill_missing_prices(bars, self.frames)
+ baseline_prices /= baseline_prices[0]
+
+ fig = make_subplots(
+ rows=1,
+ cols=2,
+ shared_xaxes=False,
+ specs=[
+ [{"secondary_y": True}, {"type": "table"}],
+ ],
+ column_width=[0.75, 0.25],
+ horizontal_spacing=0.01,
+ subplot_titles=("资产曲线", "策略指标"),
+ )
+
+ fig.add_trace(await self._metrics_trace(), row=1, col=2)
+
+ if self.indicator is not None:
+ indicator_on_hover = self.indicator["value"]
+ else:
+ indicator_on_hover = None
+
+ baseline_name = (
+ await Security.alias(self.baseline_code) if self.baseline_code else "基准"
+ )
+
+ baseline_trace = go.Scatter(
+ y=baseline_prices,
+ x=self.ticks,
+ mode="lines",
+ name=baseline_name,
+ showlegend=True,
+ text=indicator_on_hover,
+ hovertemplate="<br>净值:%{y:.2f}" + "<br>指标:%{text:.1f}",
+ )
+ fig.add_trace(baseline_trace, row=1, col=1)
+
+ nv_trace = go.Scatter(
+ y=self.nv,
+ x=self.ticks,
+ mode="lines",
+ name="策略",
+ showlegend=True,
+ hovertemplate="<br>净值:%{y:.2f}",
+ )
+ fig.add_trace(nv_trace, row=1, col=1)
+
+ if self.indicator is not None:
+ ind_trace = go.Scatter(
+ y=self.indicator["value"],
+ x=self.ticks,
+ mode="lines",
+ name="indicator",
+ showlegend=True,
+ visible="legendonly",
+ )
+ fig.add_trace(ind_trace, row=1, col=1, secondary_y=True)
+
+ for trace in await self._trade_info_trace():
+ fig.add_trace(trace, row=1, col=1)
+
+ fig.update_xaxes(type="category", tickangle=45, nticks=len(self.ticks) // 5)
+ fig.update_layout(margin=dict(l=20, r=20, t=50, b=50), width=1040, height=435)
+ fig.update_layout(
+ hovermode="x unified", hoverlabel=dict(bgcolor="rgba(255,255,255,0.8)")
+ )
+ fig.show()
+
__init__(self, bills, metrics, baseline_code='399300.XSHE', indicator=None)
+
+
+ special
+
+
+¶Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| bills | dict | 回测生成的账单,通过Strategy.bills获得 | required |
+| metrics | dict | 回测生成的指标,通过strategy.metrics获得 | required |
+| baseline_code | str | 基准证券代码 | '399300.XSHE' |
+| indicator | Optional[pandas.DataFrame] | 回测时使用的指标。如果存在,将叠加到策略回测图上。它应该是一个以日期为索引,指标值列名为"value"的pandas.DataFrame。如果不提供,将不会绘制指标图 | None |
+
omicron/plotting/metrics.py
def __init__(
+ self,
+ bills: dict,
+ metrics: dict,
+ baseline_code: str = "399300.XSHE",
+ indicator: Optional[pd.DataFrame] = None,
+):
+ """
+ Args:
+ bills: 回测生成的账单,通过Strategy.bills获得
+ metrics: 回测生成的指标,通过strategy.metrics获得
+ baseline_code: 基准证券代码
+ indicator: 回测时使用的指标。如果存在,将叠加到策略回测图上。它应该是一个以日期为索引,指标值列名为"value"的pandas.DataFrame。如果不提供,将不会绘制指标图
+ """
+ self.metrics = metrics
+ self.trades = bills["trades"]
+ self.positions = bills["positions"]
+ self.start = arrow.get(bills["assets"][0][0]).date()
+ self.end = arrow.get(bills["assets"][-1][0]).date()
+
+ self.frames = [
+ tf.int2date(f) for f in tf.get_frames(self.start, self.end, FrameType.DAY)
+ ]
+
+ if indicator is not None:
+ self.indicator = indicator.join(
+ pd.Series(index=self.frames, name="frames", dtype=np.float64),
+ how="right",
+ )
+ else:
+ self.indicator = None
+
+ # 记录日期到下标的反向映射
+ self._frame2pos = {f: i for i, f in enumerate(self.frames)}
+ self.ticks = self._format_tick(self.frames)
+
+ # TODO: there's bug in backtesting, temporarily fix here
+ df = pd.DataFrame(self.frames, columns=["frame"])
+ df["assets"] = np.nan
+ assets = pd.DataFrame(bills["assets"], columns=["frame", "assets"])
+ df["assets"] = assets["assets"]
+ self.assets = df.fillna(method="ffill")["assets"].to_numpy()
+ self.nv = self.assets / self.assets[0]
+
+ self.baseline_code = baseline_code or "399300.XSHE"
+
plot(self)
+
+
+ async
+
+
+¶绘制资产曲线及回测指标图
+ +omicron/plotting/metrics.py
async def plot(self):
+ """绘制资产曲线及回测指标图"""
+ n = len(self.assets)
+ bars = await Stock.get_bars(self.baseline_code, n, FrameType.DAY, self.end)
+
+ baseline_prices = self._fill_missing_prices(bars, self.frames)
+ baseline_prices /= baseline_prices[0]
+
+ fig = make_subplots(
+ rows=1,
+ cols=2,
+ shared_xaxes=False,
+ specs=[
+ [{"secondary_y": True}, {"type": "table"}],
+ ],
+ column_width=[0.75, 0.25],
+ horizontal_spacing=0.01,
+ subplot_titles=("资产曲线", "策略指标"),
+ )
+
+ fig.add_trace(await self._metrics_trace(), row=1, col=2)
+
+ if self.indicator is not None:
+ indicator_on_hover = self.indicator["value"]
+ else:
+ indicator_on_hover = None
+
+ baseline_name = (
+ await Security.alias(self.baseline_code) if self.baseline_code else "基准"
+ )
+
+ baseline_trace = go.Scatter(
+ y=baseline_prices,
+ x=self.ticks,
+ mode="lines",
+ name=baseline_name,
+ showlegend=True,
+ text=indicator_on_hover,
+ hovertemplate="<br>净值:%{y:.2f}" + "<br>指标:%{text:.1f}",
+ )
+ fig.add_trace(baseline_trace, row=1, col=1)
+
+ nv_trace = go.Scatter(
+ y=self.nv,
+ x=self.ticks,
+ mode="lines",
+ name="策略",
+ showlegend=True,
+ hovertemplate="<br>净值:%{y:.2f}",
+ )
+ fig.add_trace(nv_trace, row=1, col=1)
+
+ if self.indicator is not None:
+ ind_trace = go.Scatter(
+ y=self.indicator["value"],
+ x=self.ticks,
+ mode="lines",
+ name="indicator",
+ showlegend=True,
+ visible="legendonly",
+ )
+ fig.add_trace(ind_trace, row=1, col=1, secondary_y=True)
+
+ for trace in await self._trade_info_trace():
+ fig.add_trace(trace, row=1, col=1)
+
+ fig.update_xaxes(type="category", tickangle=45, nticks=len(self.ticks) // 5)
+ fig.update_layout(margin=dict(l=20, r=20, t=50, b=50), width=1040, height=435)
+ fig.update_layout(
+ hovermode="x unified", hoverlabel=dict(bgcolor="rgba(255,255,255,0.8)")
+ )
+ fig.show()
+
+Query
+
+
+
+¶证券信息查询对象
+证券信息查询对象,由`Security.select()`方法生成,支持链式查询。通过`eval`函数结束链式调用并生成查询结果。
omicron/models/security.py
class Query:
+ """证券信息查询对象
+
+ 证券信息查询对象,由`Security.select()`方法生成,支持链式查询。通过`eval`函数结束链式调用并生成查询结果。
+ """
+
+ def __init__(self, target_date: datetime.date = None):
+ if target_date is None:
+ # 聚宽不一定会及时更新数据,因此db中不存放当天的数据,如果传空,查cache
+ self.target_date = None
+ else:
+ # 如果是交易日,取当天,否则取前一天
+ self.target_date = tf.day_shift(target_date, 0)
+
+ # 名字,显示名,类型过滤器
+ self._name_pattern = None # 字母名字
+ self._alias_pattern = None # 显示名
+ self._type_pattern = None # 不指定则默认为全部,如果传入空值则只选择股票和指数
+ # 开关选项
+ self._exclude_kcb = False # 科创板
+ self._exclude_cyb = False # 创业板
+ self._exclude_st = False # ST
+ self._include_exit = False # 是否包含已退市证券(默认不包括当天退市的)
+ # 下列开关优先级高于上面的
+ self._only_kcb = False
+ self._only_cyb = False
+ self._only_st = False
+
+ def only_cyb(self) -> "Query":
+ """返回结果中只包含创业板股票"""
+ self._only_cyb = True # 高优先级
+ self._exclude_cyb = False
+ self._only_kcb = False
+ self._only_st = False
+ return self
+
+ def only_st(self) -> "Query":
+ """返回结果中只包含ST类型的证券"""
+ self._only_st = True # 高优先级
+ self._exclude_st = False
+ self._only_kcb = False
+ self._only_cyb = False
+ return self
+
+ def only_kcb(self) -> "Query":
+ """返回结果中只包含科创板股票"""
+ self._only_kcb = True # 高优先级
+ self._exclude_kcb = False
+ self._only_cyb = False
+ self._only_st = False
+ return self
+
+ def exclude_st(self) -> "Query":
+ """从返回结果中排除ST类型的股票"""
+ self._exclude_st = True
+ self._only_st = False
+ return self
+
+ def exclude_cyb(self) -> "Query":
+ """从返回结果中排除创业板类型的股票"""
+ self._exclude_cyb = True
+ self._only_cyb = False
+ return self
+
+ def exclude_kcb(self) -> "Query":
+ """从返回结果中排除科创板类型的股票"""
+ self._exclude_kcb = True
+ self._only_kcb = False
+ return self
+
+ def include_exit(self) -> "Query":
+ """从返回结果中包含已退市的证券"""
+ self._include_exit = True
+ return self
+
+ def types(self, types: List[str]) -> "Query":
+ """选择类型在`types`中的证券品种
+
+ 如果不调用此方法,默认选择所有股票类型。
+ 如果调用此方法但不传入参数,默认选择指数+股票
+ Args:
+ types: 有效的类型包括: 对股票指数而言是('index', 'stock'),对基金而言则是('etf', 'fjb', 'mmf', 'reits', 'fja', 'fjm', 'lof')
+ """
+ if types is None or isinstance(types, List) is False:
+ return self
+
+ if len(types) == 0:
+ self._type_pattern = ["index", "stock"]
+ else:
+ tmp = set(types)
+ self._type_pattern = list(tmp)
+
+ return self
+
+ def name_like(self, name: str) -> "Query":
+ """查找股票/证券名称中出现`name`的品种
+
+ 注意这里的证券名称并不是其显示名。比如对中国平安601318.XSHG来说,它的名称是ZGPA,而不是“中国平安”。
+
+ Args:
+ name: 待查找的名字,比如"ZGPA"
+
+ """
+ if name is None or len(name) == 0:
+ self._name_pattern = None
+ else:
+ self._name_pattern = name
+
+ return self
+
+ def alias_like(self, display_name: str) -> "Query":
+ """查找股票/证券显示名中出现`display_name的品种
+
+ Args:
+ display_name: 显示名,比如“中国平安"
+ """
+ if display_name is None or len(display_name) == 0:
+ self._alias_pattern = None
+ else:
+ self._alias_pattern = display_name
+
+ return self
+
+ async def eval(self) -> List[str]:
+ """对查询结果进行求值,返回code列表
+
+ Returns:
+ 代码列表
+ """
+ logger.debug("eval, date: %s", self.target_date)
+ logger.debug(
+ "eval, names and types: %s, %s, %s",
+ self._name_pattern,
+ self._alias_pattern,
+ self._type_pattern,
+ )
+ logger.debug(
+ "eval, exclude and include: %s, %s, %s, %s",
+ self._exclude_cyb,
+ self._exclude_st,
+ self._exclude_kcb,
+ self._include_exit,
+ )
+ logger.debug(
+ "eval, only: %s, %s, %s ", self._only_cyb, self._only_st, self._only_kcb
+ )
+
+ date_in_cache = await cache.security.get("security:latest_date")
+ if date_in_cache: # 无此数据说明omega有某些问题,不处理
+ _date = arrow.get(date_in_cache).date()
+ else:
+ now = datetime.datetime.now()
+ _date = tf.day_shift(now, 0)
+
+ # 确定数据源,cache为当天8点之后获取的数据,数据库存放前一日和更早的数据
+ if not self.target_date or self.target_date >= _date:
+ self.target_date = _date
+
+ records = None
+ if self.target_date == _date: # 从内存中查找,如果缓存中的数据已更新,重新加载到内存
+ secs = await cache.security.lrange("security:all", 0, -1)
+ if len(secs) != 0:
+ # using np.datetime64[s]
+ records = np.array(
+ [tuple(x.split(",")) for x in secs], dtype=security_info_dtype
+ )
+ else:
+ records = await Security.load_securities_from_db(self.target_date)
+ if records is None:
+ return None
+
+ results = []
+ self._type_pattern = self._type_pattern or SecurityType.STOCK.value
+ for record in records:
+ if self._type_pattern is not None:
+ if record["type"] not in self._type_pattern:
+ continue
+ if self._name_pattern is not None:
+ if record["name"].find(self._name_pattern) == -1:
+ continue
+ if self._alias_pattern is not None:
+ if record["alias"].find(self._alias_pattern) == -1:
+ continue
+
+ # 创业板,科创板,ST暂时限定为股票类型
+ if self._only_cyb:
+ if record["type"] != SecurityType.STOCK.value or not (
+ record["code"][:3] in ("300", "301")
+ ):
+ continue
+ if self._only_kcb:
+ if (
+ record["type"] != SecurityType.STOCK.value
+ or record["code"].startswith("688") is False
+ ):
+ continue
+ if self._only_st:
+ if (
+ record["type"] != SecurityType.STOCK.value
+ or record["alias"].find("ST") == -1
+ ):
+ continue
+ if self._exclude_cyb:
+ if record["type"] == SecurityType.STOCK.value and record["code"][
+ :3
+ ] in ("300", "301"):
+ continue
+ if self._exclude_st:
+ if (
+ record["type"] == SecurityType.STOCK.value
+ and record["alias"].find("ST") != -1
+ ):
+ continue
+ if self._exclude_kcb:
+ if record["type"] == SecurityType.STOCK.value and record[
+ "code"
+ ].startswith("688"):
+ continue
+
+ # 退市暂不限定是否为股票
+ if self._include_exit is False:
+ d1 = convert_nptime_to_datetime(record["end"]).date()
+ if d1 < self.target_date:
+ continue
+
+ results.append(record["code"])
+
+ # 返回所有查询到的结果
+ return results
+
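+一个链式查询的使用示意(假设omicron已初始化):
+
+```python
+from omicron.models.security import Security
+
+# 选出显示名中含“银行”、且排除ST与科创板的股票代码
+query = Security.select().types(["stock"]).alias_like("银行")
+codes = await query.exclude_st().exclude_kcb().eval()
+```
+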
alias_like(self, display_name)
+
+
+¶查找股票/证券显示名中出现`display_name`的品种
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| display_name | str | 显示名,比如“中国平安” | required |
omicron/models/security.py
def alias_like(self, display_name: str) -> "Query":
+ """查找股票/证券显示名中出现`display_name的品种
+
+ Args:
+ display_name: 显示名,比如“中国平安"
+ """
+ if display_name is None or len(display_name) == 0:
+ self._alias_pattern = None
+ else:
+ self._alias_pattern = display_name
+
+ return self
+
eval(self)
+
+
+ async
+
+
+¶对查询结果进行求值,返回code列表
+Returns:
+
+| Type | Description |
+| --- | --- |
+| List[str] | 代码列表 |
+
omicron/models/security.py
async def eval(self) -> List[str]:
+ """对查询结果进行求值,返回code列表
+
+ Returns:
+ 代码列表
+ """
+ logger.debug("eval, date: %s", self.target_date)
+ logger.debug(
+ "eval, names and types: %s, %s, %s",
+ self._name_pattern,
+ self._alias_pattern,
+ self._type_pattern,
+ )
+ logger.debug(
+ "eval, exclude and include: %s, %s, %s, %s",
+ self._exclude_cyb,
+ self._exclude_st,
+ self._exclude_kcb,
+ self._include_exit,
+ )
+ logger.debug(
+ "eval, only: %s, %s, %s ", self._only_cyb, self._only_st, self._only_kcb
+ )
+
+ date_in_cache = await cache.security.get("security:latest_date")
+ if date_in_cache: # 无此数据说明omega有某些问题,不处理
+ _date = arrow.get(date_in_cache).date()
+ else:
+ now = datetime.datetime.now()
+ _date = tf.day_shift(now, 0)
+
+ # 确定数据源,cache为当天8点之后获取的数据,数据库存放前一日和更早的数据
+ if not self.target_date or self.target_date >= _date:
+ self.target_date = _date
+
+ records = None
+ if self.target_date == _date: # 从内存中查找,如果缓存中的数据已更新,重新加载到内存
+ secs = await cache.security.lrange("security:all", 0, -1)
+ if len(secs) != 0:
+ # using np.datetime64[s]
+ records = np.array(
+ [tuple(x.split(",")) for x in secs], dtype=security_info_dtype
+ )
+ else:
+ records = await Security.load_securities_from_db(self.target_date)
+ if records is None:
+ return None
+
+ results = []
+ self._type_pattern = self._type_pattern or SecurityType.STOCK.value
+ for record in records:
+ if self._type_pattern is not None:
+ if record["type"] not in self._type_pattern:
+ continue
+ if self._name_pattern is not None:
+ if record["name"].find(self._name_pattern) == -1:
+ continue
+ if self._alias_pattern is not None:
+ if record["alias"].find(self._alias_pattern) == -1:
+ continue
+
+ # 创业板,科创板,ST暂时限定为股票类型
+ if self._only_cyb:
+ if record["type"] != SecurityType.STOCK.value or not (
+ record["code"][:3] in ("300", "301")
+ ):
+ continue
+ if self._only_kcb:
+ if (
+ record["type"] != SecurityType.STOCK.value
+ or record["code"].startswith("688") is False
+ ):
+ continue
+ if self._only_st:
+ if (
+ record["type"] != SecurityType.STOCK.value
+ or record["alias"].find("ST") == -1
+ ):
+ continue
+ if self._exclude_cyb:
+ if record["type"] == SecurityType.STOCK.value and record["code"][
+ :3
+ ] in ("300", "301"):
+ continue
+ if self._exclude_st:
+ if (
+ record["type"] == SecurityType.STOCK.value
+ and record["alias"].find("ST") != -1
+ ):
+ continue
+ if self._exclude_kcb:
+ if record["type"] == SecurityType.STOCK.value and record[
+ "code"
+ ].startswith("688"):
+ continue
+
+ # 退市暂不限定是否为股票
+ if self._include_exit is False:
+ d1 = convert_nptime_to_datetime(record["end"]).date()
+ if d1 < self.target_date:
+ continue
+
+ results.append(record["code"])
+
+ # 返回所有查询到的结果
+ return results
+
exclude_cyb(self)
+
+
+¶从返回结果中排除创业板类型的股票
+ +omicron/models/security.py
def exclude_cyb(self) -> "Query":
+ """从返回结果中排除创业板类型的股票"""
+ self._exclude_cyb = True
+ self._only_cyb = False
+ return self
+
exclude_kcb(self)
+
+
+¶从返回结果中排除科创板类型的股票
+ +omicron/models/security.py
def exclude_kcb(self) -> "Query":
+ """从返回结果中排除科创板类型的股票"""
+ self._exclude_kcb = True
+ self._only_kcb = False
+ return self
+
exclude_st(self)
+
+
+¶从返回结果中排除ST类型的股票
+ +omicron/models/security.py
def exclude_st(self) -> "Query":
+ """从返回结果中排除ST类型的股票"""
+ self._exclude_st = True
+ self._only_st = False
+ return self
+
include_exit(self)
+
+
+¶从返回结果中包含已退市的证券
+ +omicron/models/security.py
def include_exit(self) -> "Query":
+ """从返回结果中包含已退市的证券"""
+ self._include_exit = True
+ return self
+
name_like(self, name)
+
+
+¶查找股票/证券名称中出现`name`的品种
+
+注意这里的证券名称并不是其显示名。比如对中国平安601318.XSHG来说,它的名称是ZGPA,而不是“中国平安”。
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| name | str | 待查找的名字,比如"ZGPA" | required |
omicron/models/security.py
def name_like(self, name: str) -> "Query":
+ """查找股票/证券名称中出现`name`的品种
+
+ 注意这里的证券名称并不是其显示名。比如对中国平安601318.XSHG来说,它的名称是ZGPA,而不是“中国平安”。
+
+ Args:
+ name: 待查找的名字,比如"ZGPA"
+
+ """
+ if name is None or len(name) == 0:
+ self._name_pattern = None
+ else:
+ self._name_pattern = name
+
+ return self
+
only_cyb(self)
+
+
+¶返回结果中只包含创业板股票
+ +omicron/models/security.py
def only_cyb(self) -> "Query":
+ """返回结果中只包含创业板股票"""
+ self._only_cyb = True # 高优先级
+ self._exclude_cyb = False
+ self._only_kcb = False
+ self._only_st = False
+ return self
+
only_kcb(self)
+
+
+¶返回结果中只包含科创板股票
+ +omicron/models/security.py
def only_kcb(self) -> "Query":
+ """返回结果中只包含科创板股票"""
+ self._only_kcb = True # 高优先级
+ self._exclude_kcb = False
+ self._only_cyb = False
+ self._only_st = False
+ return self
+
only_st(self)
+
+
+¶返回结果中只包含ST类型的证券
+ +omicron/models/security.py
def only_st(self) -> "Query":
+ """返回结果中只包含ST类型的证券"""
+ self._only_st = True # 高优先级
+ self._exclude_st = False
+ self._only_kcb = False
+ self._only_cyb = False
+ return self
+
types(self, types)
+
+
+¶选择类型在`types`中的证券品种
+
+如果不调用此方法,默认选择所有股票类型。如果调用此方法但不传入参数,默认选择指数+股票。
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| types | List[str] | 有效的类型包括: 对股票指数而言是('index', 'stock'),对基金而言则是('etf', 'fjb', 'mmf', 'reits', 'fja', 'fjm', 'lof') | required |
omicron/models/security.py
def types(self, types: List[str]) -> "Query":
+ """选择类型在`types`中的证券品种
+
+ 如果不调用此方法,默认选择所有股票类型。
+ 如果调用此方法但不传入参数,默认选择指数+股票
+ Args:
+ types: 有效的类型包括: 对股票指数而言是('index', 'stock'),对基金而言则是('etf', 'fjb', 'mmf', 'reits', 'fja', 'fjm', 'lof')
+ """
+ if types is None or isinstance(types, List) is False:
+ return self
+
+ if len(types) == 0:
+ self._type_pattern = ["index", "stock"]
+ else:
+ tmp = set(types)
+ self._type_pattern = list(tmp)
+
+ return self
+
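+一个调用示意:
+
+```python
+# 只选基金中的ETF与LOF;若传入空列表,则默认选择指数+股票
+codes = await Security.select().types(["etf", "lof"]).eval()
+```
+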
+Security
+
+
+
+¶omicron/models/security.py
class Security:
+ _securities = []
+ _securities_date = None
+ _security_types = set()
+ _stocks = []
+
+ @classmethod
+ async def init(cls):
+ """初始化Security.
+
+ 一般而言,omicron的使用者无须调用此方法,它会在omicron初始化(通过`omicron.init`)时,被自动调用。
+
+ Raises:
+ DataNotReadyError: 如果omicron未初始化,或者cache中未加载最新证券列表,则抛出此异常。
+ """
+ # read all securities from redis, 7111 records now
+ # {'index', 'stock'}
+ # {'fjb', 'mmf', 'reits', 'fja', 'fjm'}
+ # {'etf', 'lof'}
+ if len(cls._securities) > 100:
+ return True
+
+ secs = await cls.load_securities()
+ if secs is None or len(secs) == 0: # pragma: no cover
+ raise DataNotReadyError(
+ "No securities in cache, make sure you have called omicron.init() first."
+ )
+
+ print("init securities done")
+ return True
+
+ @classmethod
+ async def load_securities(cls):
+ """加载所有证券的信息,并缓存到内存中
+
+ 一般而言,omicron的使用者无须调用此方法,它会在omicron初始化(通过`omicron.init`)时,被自动调用。
+ """
+ secs = await cache.security.lrange("security:all", 0, -1)
+ if len(secs) != 0:
+ # using np.datetime64[s]
+ _securities = np.array(
+ [tuple(x.split(",")) for x in secs], dtype=security_info_dtype
+ )
+
+ # 更新证券类型列表
+ cls._securities = _securities
+ cls._security_types = set(_securities["type"])
+ cls._stocks = _securities[
+ (_securities["type"] == "stock") | (_securities["type"] == "index")
+ ]
+ logger.info(
+ "%d securities loaded, types: %s", len(_securities), cls._security_types
+ )
+
+ date_in_cache = await cache.security.get("security:latest_date")
+ if date_in_cache is not None:
+ cls._securities_date = arrow.get(date_in_cache).date()
+ else:
+ cls._securities_date = datetime.date.today()
+
+ return _securities
+ else: # pragma: no cover
+ return None
+
+ @classmethod
+ async def get_security_types(cls):
+ if cls._security_types:
+ return list(cls._security_types)
+ else:
+ return None
+
+ @classmethod
+ def get_stock(cls, code) -> NDArray[security_info_dtype]:
+ """根据`code`来查找对应的股票(含指数)对象信息。
+
+ 如果您只有股票代码,想知道该代码对应的股票名称、别名(显示名)、上市日期等信息,就可以使用此方法来获取相关信息。
+
+ 返回类型为`security_info_dtype`的numpy数组,但仅包含一个元素。您可以象字典一样存取它,比如
+ ```python
+ item = Security.get_stock("000001.XSHE")
+ print(item["alias"])
+ ```
+ 显示为"平安银行"
+
+ Args:
+ code: 待查询的股票/指数代码
+
+ Returns:
+ 类型为`security_info_dtype`的numpy数组,但仅包含一个元素
+ """
+ if len(cls._securities) == 0:
+ return None
+
+ tmp = cls._securities[cls._securities["code"] == code]
+ if len(tmp) > 0:
+ if tmp["type"] in ["stock", "index"]:
+ return tmp[0]
+
+ return None
+
+ @classmethod
+ def fuzzy_match_ex(cls, query: str) -> Dict[str, Tuple]:
+ # fixme: 此方法与Stock.fuzzy_match重复,并且进行了类型限制,使得其不适合放在Security里,以及作为一个通用方法
+
+ query = query.upper()
+ if re.match(r"\d+", query):
+ return {
+ sec["code"]: sec.tolist()
+ for sec in cls._securities
+ if sec["code"].find(query) != -1 and sec["type"] == "stock"
+ }
+ elif re.match(r"[A-Z]+", query):
+ return {
+ sec["code"]: sec.tolist()
+ for sec in cls._securities
+ if sec["name"].startswith(query) and sec["type"] == "stock"
+ }
+ else:
+ return {
+ sec["code"]: sec.tolist()
+ for sec in cls._securities
+ if sec["alias"].find(query) != -1 and sec["type"] == "stock"
+ }
+
+ @classmethod
+ async def info(cls, code, date=None):
+ _obj = await cls.query_security_via_date(code, date)
+ if _obj is None:
+ return None
+
+ # "_time", "code", "type", "alias", "end", "ipo", "name"
+ d1 = convert_nptime_to_datetime(_obj["ipo"]).date()
+ d2 = convert_nptime_to_datetime(_obj["end"]).date()
+ return {
+ "type": _obj["type"],
+ "display_name": _obj["alias"],
+ "alias": _obj["alias"],
+ "end": d2,
+ "start": d1,
+ "name": _obj["name"],
+ }
+
+ @classmethod
+ async def name(cls, code, date=None):
+ _security = await cls.query_security_via_date(code, date)
+ if _security is None:
+ return None
+ return _security["name"]
+
+ @classmethod
+ async def alias(cls, code, date=None):
+ return await cls.display_name(code, date)
+
+ @classmethod
+ async def display_name(cls, code, date=None):
+ _security = await cls.query_security_via_date(code, date)
+ if _security is None:
+ return None
+ return _security["alias"]
+
+ @classmethod
+ async def start_date(cls, code, date=None):
+ _security = await cls.query_security_via_date(code, date)
+ if _security is None:
+ return None
+ return convert_nptime_to_datetime(_security["ipo"]).date()
+
+ @classmethod
+ async def end_date(cls, code, date=None):
+ _security = await cls.query_security_via_date(code, date)
+ if _security is None:
+ return None
+ return convert_nptime_to_datetime(_security["end"]).date()
+
+ @classmethod
+ async def security_type(cls, code, date=None) -> SecurityType:
+ _security = await cls.query_security_via_date(code, date)
+ if _security is None:
+ return None
+ return _security["type"]
+
+ @classmethod
+ async def query_security_via_date(cls, code: str, date: datetime.date = None):
+ if date is None: # 从内存中查找,如果缓存中的数据已更新,重新加载到内存
+ date_in_cache = await cache.security.get("security:latest_date")
+ if date_in_cache is not None:
+ date = arrow.get(date_in_cache).date()
+ if date > cls._securities_date:
+ await cls.load_securities()
+ results = cls._securities[cls._securities["code"] == code]
+ else: # 从influxdb查找
+ date = tf.day_shift(date, 0)
+ results = await cls.load_securities_from_db(date, code)
+
+ if results is not None and len(results) > 0:
+ return results[0]
+ else:
+ return None
+
+ @classmethod
+ def select(cls, date: datetime.date = None) -> Query:
+ if date is None:
+ return Query(target_date=None)
+ else:
+ return Query(target_date=date)
+
+ @classmethod
+ async def update_secs_cache(cls, dt: datetime.date, securities: List[Tuple]):
+ """更新证券列表到缓存数据库中
+
+ Args:
+ dt: 证券列表归属的日期
+ securities: 证券列表, 元素为元组,分别为代码、别名、名称、IPO日期、退市日和证券类型
+ """
+ # stock: {'index', 'stock'}
+ # funds: {'fjb', 'mmf', 'reits', 'fja', 'fjm'}
+ # {'etf', 'lof'}
+ key = "security:all"
+ pipeline = cache.security.pipeline()
+ pipeline.delete(key)
+ for code, alias, name, start, end, _type in securities:
+ pipeline.rpush(key, f"{code},{alias},{name},{start}," f"{end},{_type}")
+ await pipeline.execute()
+ logger.info("all securities saved to cache %s, %d secs", key, len(securities))
+
+ # update latest date info
+ await cache.security.set("security:latest_date", dt.strftime("%Y-%m-%d"))
+
+ @classmethod
+ async def save_securities(cls, securities: List[str], dt: datetime.date):
+ """保存指定的证券信息到缓存中,并且存入influxdb,定时job调用本接口
+
+ Args:
+ securities: 证券代码列表。
+ """
+ # stock: {'index', 'stock'}
+ # funds: {'fjb', 'mmf', 'reits', 'fja', 'fjm'}
+ # {'etf', 'lof'}
+ if dt is None or len(securities) == 0:
+ return
+
+ measurement = "security_list"
+ client = get_influx_client()
+
+ # code, alias, name, start, end, type
+ security_list = np.array(
+ [
+ (dt, x[0], f"{x[0]},{x[1]},{x[2]},{x[3]},{x[4]},{x[5]}")
+ for x in securities
+ ],
+ dtype=security_db_dtype,
+ )
+ await client.save(
+ security_list, measurement, time_key="frame", tag_keys=["code"]
+ )
+
+ @classmethod
+ async def load_securities_from_db(
+ cls, target_date: datetime.date, code: str = None
+ ):
+ if target_date is None:
+ return None
+
+ client = get_influx_client()
+ measurement = "security_list"
+
+ flux = (
+ Flux()
+ .measurement(measurement)
+ .range(target_date, target_date)
+ .bucket(client._bucket)
+ .fields(["info"])
+ )
+ if code is not None and len(code) > 0:
+ flux.tags({"code": code})
+
+ data = await client.query(flux)
+ if len(data) == 2: # \r\n
+ return None
+
+ ds = DataframeDeserializer(
+ sort_values="_time",
+ usecols=["_time", "code", "info"],
+ time_col="_time",
+ engine="c",
+ )
+ actual = ds(data)
+ secs = actual.to_records(index=False)
+
+ if len(secs) != 0:
+ # "_time", "code", "code, alias, name, start, end, type"
+ _securities = np.array(
+ [tuple(x["info"].split(",")) for x in secs], dtype=security_info_dtype
+ )
+ return _securities
+ else:
+ return None
+
+ @classmethod
+ async def get_datescope_from_db(cls):
+ # fixme: 函数名无法反映用途,需要增加文档注释,说明该函数的作用,或者不应该出现在此类中?
+ client = get_influx_client()
+ measurement = "security_list"
+
+ date1 = arrow.get("2005-01-01").date()
+ date2 = arrow.now().naive.date()
+
+ flux = (
+ Flux()
+ .measurement(measurement)
+ .range(date1, date2)
+ .bucket(client._bucket)
+ .tags({"code": "000001.XSHE"})
+ )
+
+ data = await client.query(flux)
+ if len(data) == 2: # \r\n
+ return None, None
+
+ ds = DataframeDeserializer(
+ sort_values="_time", usecols=["_time"], time_col="_time", engine="c"
+ )
+ actual = ds(data)
+ secs = actual.to_records(index=False)
+
+ if len(secs) != 0:
+ d1 = convert_nptime_to_datetime(secs[0]["_time"])
+ d2 = convert_nptime_to_datetime(secs[len(secs) - 1]["_time"])
+ return d1.date(), d2.date()
+ else:
+ return None, None
+
+ @classmethod
+ async def _notify_special_bonusnote(cls, code, note, cancel_date):
+ # fixme: 这个函数应该出现在omega中?
+ default_cancel_date = datetime.date(2099, 1, 1) # 默认无取消公告
+ # report this special event to notify user
+ if cancel_date != default_cancel_date:
+ ding("security %s, bonus_cancel_pub_date %s" % (code, cancel_date))
+
+ if note.find("流通") != -1: # 检查是否有“流通股”文字
+ ding("security %s, special xrxd note: %s" % (code, note))
+
+ @classmethod
+ async def save_xrxd_reports(cls, reports: List[str], dt: datetime.date):
+ # fixme: 此函数应该属于omega?
+
+ """保存1年内的分红送股信息,并且存入influxdb,定时job调用本接口
+
+ Args:
+ reports: 分红送股公告
+ """
+ # code(0), a_xr_date, board_plan_bonusnote, bonus_ratio_rmb(3), dividend_ratio, transfer_ratio(5),
+ # at_bonus_ratio_rmb(6), report_date, plan_progress, implementation_bonusnote, bonus_cancel_pub_date(10)
+
+ if len(reports) == 0 or dt is None:
+ return
+
+ # read reports from db and convert to dict map
+ reports_in_db = {}
+ dt_start = dt - datetime.timedelta(days=366) # 往前回溯366天
+ dt_end = dt + datetime.timedelta(days=366) # 往后延长366天
+ existing_records = await cls._load_xrxd_from_db(None, dt_start, dt_end)
+ for record in existing_records:
+ code = record[0]
+ if code not in reports_in_db:
+ reports_in_db[code] = [record]
+ else:
+ reports_in_db[code].append(record)
+
+ records = [] # 准备写入db
+
+ for x in reports:
+ code = x[0]
+ note = x[2]
+ cancel_date = x[10]
+
+ existing_items = reports_in_db.get(code, None)
+ if existing_items is None: # 新记录
+ record = (
+ x[1],
+ x[0],
+ f"{x[0]}|{x[1]}|{x[2]}|{x[3]}|{x[4]}|{x[5]}|{x[6]}|{x[7]}|{x[8]}|{x[9]}|{x[10]}",
+ )
+ records.append(record)
+ await cls._notify_special_bonusnote(code, note, cancel_date)
+ else:
+ new_record = True
+ for item in existing_items:
+ existing_date = convert_nptime_to_datetime(item[1]).date()
+ if existing_date == x[1]: # 如果xr_date相同,不更新
+ new_record = False
+ continue
+ if new_record:
+ record = (
+ x[1],
+ x[0],
+ f"{x[0]}|{x[1]}|{x[2]}|{x[3]}|{x[4]}|{x[5]}|{x[6]}|{x[7]}|{x[8]}|{x[9]}|{x[10]}",
+ )
+ records.append(record)
+ await cls._notify_special_bonusnote(code, note, cancel_date)
+
+ logger.info("save_xrxd_reports, %d records to be saved", len(records))
+ if len(records) == 0:
+ return
+
+ measurement = "security_xrxd_reports"
+ client = get_influx_client()
+ # a_xr_date(_time), code(tag), info
+ report_list = np.array(records, dtype=security_db_dtype)
+ await client.save(report_list, measurement, time_key="frame", tag_keys=["code"])
+
+ @classmethod
+ async def _load_xrxd_from_db(
+ cls, code, dt_start: datetime.date, dt_end: datetime.date
+ ):
+ if dt_start is None or dt_end is None:
+ return []
+
+ client = get_influx_client()
+ measurement = "security_xrxd_reports"
+
+ flux = (
+ Flux()
+ .measurement(measurement)
+ .range(dt_start, dt_end)
+ .bucket(client._bucket)
+ .fields(["info"])
+ )
+ if code is not None and len(code) > 0:
+ flux.tags({"code": code})
+
+ data = await client.query(flux)
+ if len(data) == 2: # \r\n
+ return []
+
+ ds = DataframeDeserializer(
+ sort_values="_time",
+ usecols=["_time", "code", "info"],
+ time_col="_time",
+ engine="c",
+ )
+ actual = ds(data)
+ secs = actual.to_records(index=False)
+
+ if len(secs) != 0:
+ _reports = np.array(
+ [tuple(x["info"].split("|")) for x in secs], dtype=xrxd_info_dtype
+ )
+ return _reports
+ else:
+ return []
+
+ @classmethod
+ async def get_xrxd_info(cls, dt: datetime.date, code: str = None):
+ if dt is None:
+ return None
+
+ # code(0), a_xr_date, board_plan_bonusnote, bonus_ratio_rmb(3), dividend_ratio, transfer_ratio(5),
+ # at_bonus_ratio_rmb(6), report_date, plan_progress, implementation_bonusnote, bonus_cancel_pub_date(10)
+ reports = await cls._load_xrxd_from_db(code, dt, dt)
+ if len(reports) == 0:
+ return None
+
+ readable_reports = []
+ for report in reports:
+ xr_date = convert_nptime_to_datetime(report[1]).date()
+ readable_reports.append(
+ {
+ "code": report[0],
+ "xr_date": xr_date,
+ "bonus": report[3],
+ "dividend": report[4],
+ "transfer": report[5],
+ "bonusnote": report[2],
+ }
+ )
+
+ return readable_reports
+
get_stock(code)
+
+
+ classmethod
+
+
+¶根据`code`来查找对应的股票(含指数)对象信息。
+
+如果您只有股票代码,想知道该代码对应的股票名称、别名(显示名)、上市日期等信息,就可以使用此方法来获取相关信息。
+
+返回类型为`security_info_dtype`的numpy数组,但仅包含一个元素。您可以像字典一样存取它,比如
+
+```python
+item = Security.get_stock("000001.XSHE")
+print(item["alias"])
+```
+
+显示为"平安银行"
+
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| code | | 待查询的股票/指数代码 | required |
+
+Returns:
+
+| Type | Description |
+| --- | --- |
+| NDArray[security_info_dtype] | 类型为`security_info_dtype`的numpy数组,但仅包含一个元素 |
+
omicron/models/security.py
@classmethod
+def get_stock(cls, code) -> NDArray[security_info_dtype]:
+ """根据`code`来查找对应的股票(含指数)对象信息。
+
+ 如果您只有股票代码,想知道该代码对应的股票名称、别名(显示名)、上市日期等信息,就可以使用此方法来获取相关信息。
+
+ 返回类型为`security_info_dtype`的numpy数组,但仅包含一个元素。您可以象字典一样存取它,比如
+ ```python
+ item = Security.get_stock("000001.XSHE")
+ print(item["alias"])
+ ```
+ 显示为"平安银行"
+
+ Args:
+ code: 待查询的股票/指数代码
+
+ Returns:
+ 类型为`security_info_dtype`的numpy数组,但仅包含一个元素
+ """
+ if len(cls._securities) == 0:
+ return None
+
+ tmp = cls._securities[cls._securities["code"] == code]
+ if len(tmp) > 0:
+ if tmp["type"] in ["stock", "index"]:
+ return tmp[0]
+
+ return None
+
init()
+
+
+ async
+ classmethod
+
+
+¶初始化Security.
+一般而言,omicron的使用者无须调用此方法,它会在omicron初始化(通过`omicron.init`)时,被自动调用。
+
+Exceptions:
+
+| Type | Description |
+| --- | --- |
+| DataNotReadyError | 如果omicron未初始化,或者cache中未加载最新证券列表,则抛出此异常。 |
+
omicron/models/security.py
@classmethod
+async def init(cls):
+ """初始化Security.
+
+ 一般而言,omicron的使用者无须调用此方法,它会在omicron初始化(通过`omicron.init`)时,被自动调用。
+
+ Raises:
+ DataNotReadyError: 如果omicron未初始化,或者cache中未加载最新证券列表,则抛出此异常。
+ """
+ # read all securities from redis, 7111 records now
+ # {'index', 'stock'}
+ # {'fjb', 'mmf', 'reits', 'fja', 'fjm'}
+ # {'etf', 'lof'}
+ if len(cls._securities) > 100:
+ return True
+
+ secs = await cls.load_securities()
+ if secs is None or len(secs) == 0: # pragma: no cover
+ raise DataNotReadyError(
+ "No securities in cache, make sure you have called omicron.init() first."
+ )
+
+ print("init securities done")
+ return True
+
load_securities()
+
+
+ async
+ classmethod
+
+
+¶加载所有证券的信息,并缓存到内存中
+一般而言,omicron的使用者无须调用此方法,它会在omicron初始化(通过`omicron.init`)时,被自动调用。
omicron/models/security.py
@classmethod
+async def load_securities(cls):
+ """加载所有证券的信息,并缓存到内存中
+
+ 一般而言,omicron的使用者无须调用此方法,它会在omicron初始化(通过`omicron.init`)时,被自动调用。
+ """
+ secs = await cache.security.lrange("security:all", 0, -1)
+ if len(secs) != 0:
+ # using np.datetime64[s]
+ _securities = np.array(
+ [tuple(x.split(",")) for x in secs], dtype=security_info_dtype
+ )
+
+ # 更新证券类型列表
+ cls._securities = _securities
+ cls._security_types = set(_securities["type"])
+ cls._stocks = _securities[
+ (_securities["type"] == "stock") | (_securities["type"] == "index")
+ ]
+ logger.info(
+ "%d securities loaded, types: %s", len(_securities), cls._security_types
+ )
+
+ date_in_cache = await cache.security.get("security:latest_date")
+ if date_in_cache is not None:
+ cls._securities_date = arrow.get(date_in_cache).date()
+ else:
+ cls._securities_date = datetime.date.today()
+
+ return _securities
+ else: # pragma: no cover
+ return None
+
save_securities(securities, dt)
+
+
+ async
+ classmethod
+
+
+¶保存指定的证券信息到缓存中,并且存入influxdb,定时job调用本接口
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| securities | List[str] | 证券代码列表。 | required |
+| dt | date | 证券列表归属的日期 | required |
omicron/models/security.py
@classmethod
+async def save_securities(cls, securities: List[str], dt: datetime.date):
+ """保存指定的证券信息到缓存中,并且存入influxdb,定时job调用本接口
+
+ Args:
+ securities: 证券代码列表。
+ """
+ # stock: {'index', 'stock'}
+ # funds: {'fjb', 'mmf', 'reits', 'fja', 'fjm'}
+ # {'etf', 'lof'}
+ if dt is None or len(securities) == 0:
+ return
+
+ measurement = "security_list"
+ client = get_influx_client()
+
+ # code, alias, name, start, end, type
+ security_list = np.array(
+ [
+ (dt, x[0], f"{x[0]},{x[1]},{x[2]},{x[3]},{x[4]},{x[5]}")
+ for x in securities
+ ],
+ dtype=security_db_dtype,
+ )
+ await client.save(
+ security_list, measurement, time_key="frame", tag_keys=["code"]
+ )
+
save_xrxd_reports(reports, dt)
+
+
+ async
+ classmethod
+
+
+¶保存1年内的分红送股信息,并且存入influxdb,定时job调用本接口
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| reports | List[str] | 分红送股公告 | required |
+| dt | date | 去重基准日期,函数将以该日期前后366天的已有记录进行去重 | required |
omicron/models/security.py
@classmethod
+async def save_xrxd_reports(cls, reports: List[str], dt: datetime.date):
+ # fixme: 此函数应该属于omega?
+
+ """保存1年内的分红送股信息,并且存入influxdb,定时job调用本接口
+
+ Args:
+ reports: 分红送股公告
+ """
+ # code(0), a_xr_date, board_plan_bonusnote, bonus_ratio_rmb(3), dividend_ratio, transfer_ratio(5),
+ # at_bonus_ratio_rmb(6), report_date, plan_progress, implementation_bonusnote, bonus_cancel_pub_date(10)
+
+ if len(reports) == 0 or dt is None:
+ return
+
+ # read reports from db and convert to dict map
+ reports_in_db = {}
+ dt_start = dt - datetime.timedelta(days=366) # 往前回溯366天
+ dt_end = dt + datetime.timedelta(days=366) # 往后延长366天
+ existing_records = await cls._load_xrxd_from_db(None, dt_start, dt_end)
+ for record in existing_records:
+ code = record[0]
+ if code not in reports_in_db:
+ reports_in_db[code] = [record]
+ else:
+ reports_in_db[code].append(record)
+
+ records = [] # 准备写入db
+
+ for x in reports:
+ code = x[0]
+ note = x[2]
+ cancel_date = x[10]
+
+ existing_items = reports_in_db.get(code, None)
+ if existing_items is None: # 新记录
+ record = (
+ x[1],
+ x[0],
+ f"{x[0]}|{x[1]}|{x[2]}|{x[3]}|{x[4]}|{x[5]}|{x[6]}|{x[7]}|{x[8]}|{x[9]}|{x[10]}",
+ )
+ records.append(record)
+ await cls._notify_special_bonusnote(code, note, cancel_date)
+ else:
+ new_record = True
+ for item in existing_items:
+ existing_date = convert_nptime_to_datetime(item[1]).date()
+ if existing_date == x[1]: # 如果xr_date相同,不更新
+ new_record = False
+ continue
+ if new_record:
+ record = (
+ x[1],
+ x[0],
+ f"{x[0]}|{x[1]}|{x[2]}|{x[3]}|{x[4]}|{x[5]}|{x[6]}|{x[7]}|{x[8]}|{x[9]}|{x[10]}",
+ )
+ records.append(record)
+ await cls._notify_special_bonusnote(code, note, cancel_date)
+
+ logger.info("save_xrxd_reports, %d records to be saved", len(records))
+ if len(records) == 0:
+ return
+
+ measurement = "security_xrxd_reports"
+ client = get_influx_client()
+ # a_xr_date(_time), code(tag), info
+ report_list = np.array(records, dtype=security_db_dtype)
+ await client.save(report_list, measurement, time_key="frame", tag_keys=["code"])
+
update_secs_cache(dt, securities)
+
+
+ async
+ classmethod
+
+
+¶更新证券列表到缓存数据库中
+Parameters:
+
+| Name | Type | Description | Default |
+| --- | --- | --- | --- |
+| dt | date | 证券列表归属的日期 | required |
+| securities | List[Tuple] | 证券列表, 元素为元组,分别为代码、别名、名称、IPO日期、退市日和证券类型 | required |
omicron/models/security.py
@classmethod
+async def update_secs_cache(cls, dt: datetime.date, securities: List[Tuple]):
+ """更新证券列表到缓存数据库中
+
+ Args:
+ dt: 证券列表归属的日期
+ securities: 证券列表, 元素为元组,分别为代码、别名、名称、IPO日期、退市日和证券类型
+ """
+ # stock: {'index', 'stock'}
+ # funds: {'fjb', 'mmf', 'reits', 'fja', 'fjm'}
+ # {'etf', 'lof'}
+ key = "security:all"
+ pipeline = cache.security.pipeline()
+ pipeline.delete(key)
+ for code, alias, name, start, end, _type in securities:
+ pipeline.rpush(key, f"{code},{alias},{name},{start}," f"{end},{_type}")
+ await pipeline.execute()
+ logger.info("all securities saved to cache %s, %d secs", key, len(securities))
+
+ # update latest date info
+ await cache.security.set("security:latest_date", dt.strftime("%Y-%m-%d"))
+
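+一个查询证券信息的使用示意(假设omicron已初始化):
+
+```python
+from omicron.models.security import Security
+
+alias = await Security.alias("000001.XSHE")      # -> "平安银行"
+ipo = await Security.start_date("000001.XSHE")   # 上市日期
+info = await Security.info("000001.XSHE")        # 含名称、显示名、上市/退市日期等
+```
+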
+Stock (Security)
+
+
+
+
+¶Stock对象用于归集某支证券(股票和指数,不包括其它投资品种)的相关信息,比如行情数据(OHLC等)、市值数据、所属概念分类等。
+ +omicron/models/stock.py
class Stock(Security):
+ """
+ Stock对象用于归集某支证券(股票和指数,不包括其它投资品种)的相关信息,比如行情数据(OHLC等)、市值数据、所属概念分类等。
+ """
+
+ _is_cache_empty = True
+
+ def __init__(self, code: str):
+ self._code = code
+ self._stock = self.get_stock(code)
+ assert self._stock, "系统中不存在该code"
+ (_, self._display_name, self._name, ipo, end, _type) = self._stock
+ self._start_date = convert_nptime_to_datetime(ipo).date()
+ self._end_date = convert_nptime_to_datetime(end).date()
+ self._type = SecurityType(_type)
+
+ @classmethod
+ def choose_listed(cls, dt: datetime.date, types: List[str] = ["stock", "index"]):
+ cond = np.array([False] * len(cls._stocks))
+ dt = datetime.datetime.combine(dt, datetime.time())
+
+ for type_ in types:
+ cond |= cls._stocks["type"] == type_
+ result = cls._stocks[cond]
+ result = result[result["end"] > dt]
+ result = result[result["ipo"] <= dt]
+ # result = np.array(result, dtype=cls.stock_info_dtype)
+ return result["code"].tolist()
+
+ @classmethod
+ def fuzzy_match(cls, query: str) -> Dict[str, Tuple]:
+ """对股票/指数进行模糊匹配查找
+
+ query可以是股票/指数代码,也可以是字母(按name查找),也可以是汉字(按显示名查找)
+
+ Args:
+ query (str): 查询字符串
+
+ Returns:
+ Dict[str, Tuple]: 查询结果,其中Tuple为(code, display_name, name, start, end, type)
+ """
+ query = query.upper()
+ if re.match(r"\d+", query):
+ return {
+ sec["code"]: sec.tolist()
+ for sec in cls._stocks
+ if sec["code"].startswith(query)
+ }
+ elif re.match(r"[A-Z]+", query):
+ return {
+ sec["code"]: sec.tolist()
+ for sec in cls._stocks
+ if sec["name"].startswith(query)
+ }
+ else:
+ return {
+ sec["code"]: sec.tolist()
+ for sec in cls._stocks
+ if sec["alias"].find(query) != -1
+ }
+
+ def __str__(self):
+ return f"{self.display_name}[{self.code}]"
+
+ @property
+ def ipo_date(self) -> datetime.date:
+ return self._start_date
+
+ @property
+ def display_name(self) -> str:
+ return self._display_name
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def end_date(self) -> datetime.date:
+ return self._end_date
+
+ @property
+ def code(self) -> str:
+ return self._code
+
+ @property
+ def sim_code(self) -> str:
+ return re.sub(r"\.XSH[EG]", "", self.code)
+
+ @property
+ def security_type(self) -> SecurityType:
+ """返回证券类型
+
+ Returns:
+ SecurityType: [description]
+ """
+ return self._type
+
+ @staticmethod
+ def simplify_code(code) -> str:
+ return re.sub(r"\.XSH[EG]", "", code)
+
+ @staticmethod
+ def format_code(code) -> str:
+ """新三板和北交所的股票, 暂不支持, 默认返回None
+ 上证A股: 600、601、603、605
+ 深证A股: 000、001
+ 中小板: 002、003
+ 创业板: 300/301
+ 科创板: 688
+ 新三板: 82、83、87、88、430、420、400
+ 北交所: 43、83、87、88
+ """
+ if not code or len(code) != 6:
+ return None
+
+ prefix = code[0]
+ if prefix in ("0", "3"):
+ return f"{code}.XSHE"
+ elif prefix == "6":
+ return f"{code}.XSHG"
+ else:
+ return None
+
+ def days_since_ipo(self) -> int:
+ """获取上市以来经过了多少个交易日
+
+ 由于受交易日历限制(2005年1月4日之前的交易日历没有),对于在之前上市的品种,都返回从2005年1月4日起的日期。
+
+ Returns:
+ int: [description]
+ """
+ epoch_start = arrow.get("2005-01-04").date()
+ ipo_day = self.ipo_date if self.ipo_date > epoch_start else epoch_start
+ return tf.count_day_frames(ipo_day, arrow.now().date())
+
+ @staticmethod
+ def qfq(bars: BarsArray) -> BarsArray:
+ """对行情数据执行前复权操作"""
+ # todo: 这里可以优化
+ if bars.size == 0:
+ return bars
+
+ last = bars[-1]["factor"]
+ for field in ["open", "high", "low", "close", "volume"]:
+ bars[field] = bars[field] * (bars["factor"] / last)
+
+ return bars
+
+ @classmethod
+ async def batch_get_min_level_bars_in_range(
+ cls,
+ codes: List[str],
+ frame_type: FrameType,
+ start: Frame,
+ end: Frame,
+ fq: bool = True,
+ ) -> Generator[Dict[str, BarsArray], None, None]:
+ """获取多支股票(指数)在[start, end)时间段内的行情数据
+
+ 如果要获取的行情数据是分钟级别(即1m, 5m, 15m, 30m和60m),使用本接口。
+
+ 停牌数据处理请见[get_bars][omicron.models.stock.Stock.get_bars]。
+
+ 本函数返回一个迭代器,使用方法示例:
+ ```
+ async for code, bars in Stock.batch_get_min_level_bars_in_range(...):
+ print(code, bars)
+ ```
+
+ 如果`end`不在`frame_type`所属的边界点上,那么,如果`end`大于等于当前缓存未收盘数据时间,则将包含未收盘数据;否则,返回的记录将截止到`tf.floor(end, frame_type)`。
+
+ Args:
+ codes: 股票/指数代码列表
+ frame_type: 帧类型
+ start: 起始时间
+ end: 结束时间。如果未指明,则取当前时间。
+ fq: 是否进行复权,如果是,则进行前复权。Defaults to True.
+
+ Returns:
+ Generator[Dict[str, BarsArray], None, None]: 迭代器,每次返回一个字典,其中key为代码,value为行情数据
+ """
+ closed_end = tf.floor(end, frame_type)
+ n = tf.count_frames(start, closed_end, frame_type)
+ max_query_size = min(cfg.influxdb.max_query_size, INFLUXDB_MAX_QUERY_SIZE)
+ batch_size = max(1, max_query_size // n)
+ ff = tf.first_min_frame(datetime.datetime.now(), frame_type)
+
+ for i in range(0, len(codes), batch_size):
+ batch_codes = codes[i : i + batch_size]
+
+ if end < ff:
+ part1 = await cls._batch_get_persisted_bars_in_range(
+ batch_codes, frame_type, start, end
+ )
+ part2 = pd.DataFrame([], columns=bars_dtype_with_code.names)
+ elif start >= ff:
+ part1 = pd.DataFrame([], columns=bars_dtype_with_code.names)
+ n = tf.count_frames(start, closed_end, frame_type) + 1
+ cached = await cls._batch_get_cached_bars_n(
+ frame_type, n, end, batch_codes
+ )
+ cached = cached[cached["frame"] >= start]
+ part2 = pd.DataFrame(cached, columns=bars_dtype_with_code.names)
+ else:
+ part1 = await cls._batch_get_persisted_bars_in_range(
+ batch_codes, frame_type, start, ff
+ )
+ n = tf.count_frames(start, closed_end, frame_type) + 1
+ cached = await cls._batch_get_cached_bars_n(
+ frame_type, n, end, batch_codes
+ )
+ part2 = pd.DataFrame(cached, columns=bars_dtype_with_code.names)
+
+ df = pd.concat([part1, part2])
+
+ for code in batch_codes:
+ filtered = df[df["code"] == code][bars_cols]
+ bars = filtered.to_records(index=False).astype(bars_dtype)
+ if fq:
+ bars = cls.qfq(bars)
+
+ yield code, bars
+
+ @classmethod
+ async def batch_get_day_level_bars_in_range(
+ cls,
+ codes: List[str],
+ frame_type: FrameType,
+ start: Frame,
+ end: Frame,
+ fq: bool = True,
+ ) -> Generator[Dict[str, BarsArray], None, None]:
+ """获取多支股票(指数)在[start, end)时间段内的行情数据
+
+ 如果要获取的行情数据是日线级别(即1d, 1w, 1M),使用本接口。
+
+ 停牌数据处理请见[get_bars][omicron.models.stock.Stock.get_bars]。
+
+ 本函数返回一个迭代器,使用方法示例:
+ ```
+ async for code, bars in Stock.batch_get_day_level_bars_in_range(...):
+ print(code, bars)
+ ```
+
+ 如果`end`不在`frame_type`所属的边界点上,那么,如果`end`大于等于当前缓存未收盘数据时间,则将包含未收盘数据;否则,返回的记录将截止到`tf.floor(end, frame_type)`。
+
+ Args:
+ codes: 代码列表
+ frame_type: 帧类型
+ start: 起始时间
+ end: 结束时间
+ fq: 是否进行复权,如果是,则进行前复权。Defaults to True.
+
+ Returns:
+ Generator[Dict[str, BarsArray], None, None]: 迭代器,每次返回一个字典,其中key为代码,value为行情数据
+ """
+ today = datetime.datetime.now().date()
+        # 日线:仅当end等于最后一个交易日边界(tf.floor(today))时,当日数据仍在缓存中
+ if frame_type == FrameType.DAY and end == tf.floor(today, frame_type):
+ from_cache = True
+ elif frame_type != FrameType.DAY and start > tf.floor(today, frame_type):
+ from_cache = True
+ else:
+ from_cache = False
+
+ n = tf.count_frames(start, end, frame_type)
+ max_query_size = min(cfg.influxdb.max_query_size, INFLUXDB_MAX_QUERY_SIZE)
+ batch_size = max(max_query_size // n, 1)
+
+ for i in range(0, len(codes), batch_size):
+ batch_codes = codes[i : i + batch_size]
+ persisted = await cls._batch_get_persisted_bars_in_range(
+ batch_codes, frame_type, start, end
+ )
+
+ if from_cache:
+ cached = await cls._batch_get_cached_bars_n(
+ frame_type, 1, end, batch_codes
+ )
+ cached = pd.DataFrame(cached, columns=bars_dtype_with_code.names)
+
+ df = pd.concat([persisted, cached])
+ else:
+ df = persisted
+
+ for code in batch_codes:
+ filtered = df[df["code"] == code][bars_cols]
+ bars = filtered.to_records(index=False).astype(bars_dtype)
+ if fq:
+ bars = cls.qfq(bars)
+
+ yield code, bars
+
+ @classmethod
+ async def get_bars_in_range(
+ cls,
+ code: str,
+ frame_type: FrameType,
+ start: Frame,
+ end: Frame = None,
+ fq=True,
+ unclosed=True,
+ ) -> BarsArray:
+ """获取指定证券(`code`)在[`start`, `end`]期间帧类型为`frame_type`的行情数据。
+
+ Args:
+ code : 证券代码
+ frame_type : 行情数据的帧类型
+ start : 起始时间
+ end : 结束时间,如果为None,则表明取到当前时间。
+ fq : 是否对行情数据执行前复权操作
+ unclosed : 是否包含未收盘的数据
+ """
+ now = datetime.datetime.now()
+
+ if frame_type in tf.day_level_frames:
+ end = end or now.date()
+ if unclosed and tf.day_shift(end, 0) == now.date():
+ part2 = await cls._get_cached_bars_n(code, 1, frame_type)
+ else:
+ part2 = np.array([], dtype=bars_dtype)
+
+ # get rest from persisted
+ part1 = await cls._get_persisted_bars_in_range(code, frame_type, start, end)
+ bars = np.concatenate((part1, part2))
+ else:
+ end = end or now
+ closed_end = tf.floor(end, frame_type)
+ ff_min1 = tf.first_min_frame(now, FrameType.MIN1)
+ if tf.day_shift(end, 0) < now.date() or end < ff_min1:
+ part1 = await cls._get_persisted_bars_in_range(
+ code, frame_type, start, end
+ )
+ part2 = np.array([], dtype=bars_dtype)
+ elif start >= ff_min1: # all in cache
+ part1 = np.array([], dtype=bars_dtype)
+ n = tf.count_frames(start, closed_end, frame_type) + 1
+ part2 = await cls._get_cached_bars_n(code, n, frame_type, end)
+ part2 = part2[part2["frame"] >= start]
+ else: # in both cache and persisted
+ ff = tf.first_min_frame(now, frame_type)
+ part1 = await cls._get_persisted_bars_in_range(
+ code, frame_type, start, ff
+ )
+ n = tf.count_frames(ff, closed_end, frame_type) + 1
+ part2 = await cls._get_cached_bars_n(code, n, frame_type, end)
+
+ if not unclosed:
+ part2 = part2[part2["frame"] <= closed_end]
+ bars = np.concatenate((part1, part2))
+
+ if fq:
+ return cls.qfq(bars)
+ else:
+ return bars
+
+ @classmethod
+ async def get_bars(
+ cls,
+ code: str,
+ n: int,
+ frame_type: FrameType,
+ end: Frame = None,
+ fq=True,
+ unclosed=True,
+ ) -> BarsArray:
+ """获取到`end`为止的`n`个行情数据。
+
+ 返回的数据是按照时间顺序递增排序的。在遇到停牌的情况时,该时段数据将被跳过,因此返回的记录可能不是交易日连续的,并且可能不足`n`个。
+
+ 如果系统当前没有到指定时间`end`的数据,将尽最大努力返回数据。调用者可以通过判断最后一条数据的时间是否等于`end`来判断是否获取到了全部数据。
+
+ Args:
+ code: 证券代码
+ n: 记录数
+ frame_type: 帧类型
+ end: 截止时间,如果未指明,则取当前时间
+ fq: 是否对返回记录进行复权。如果为`True`的话,则进行前复权。Defaults to True.
+ unclosed: 是否包含最新未收盘的数据? Defaults to True.
+
+ Returns:
+ 返回dtype为`coretypes.bars_dtype`的一维numpy数组。
+ """
+ now = datetime.datetime.now()
+ try:
+ cached = np.array([], dtype=bars_dtype)
+
+ if frame_type in tf.day_level_frames:
+ if end is None:
+ end = now.date()
+ elif type(end) == datetime.datetime:
+ end = end.date()
+ n0 = n
+ if unclosed:
+ cached = await cls._get_cached_bars_n(code, 1, frame_type)
+ if cached.size > 0:
+ # 如果缓存的未收盘日期 > end,则该缓存不是需要的
+ if cached[0]["frame"].item().date() > end:
+ cached = np.array([], dtype=bars_dtype)
+ else:
+ n0 = n - 1
+ else:
+ end = end or now
+ closed_frame = tf.floor(end, frame_type)
+
+ # fetch one more bar, in case we should discard unclosed bar
+ cached = await cls._get_cached_bars_n(code, n + 1, frame_type, end)
+ if not unclosed:
+ cached = cached[cached["frame"] <= closed_frame]
+
+ # n bars we need fetch from persisted db
+ n0 = n - cached.size
+ if n0 > 0:
+ if cached.size > 0:
+ end0 = cached[0]["frame"].item()
+ else:
+ end0 = end
+
+ bars = await cls._get_persisted_bars_n(code, frame_type, n0, end0)
+ merged = np.concatenate((bars, cached))
+ bars = merged[-n:]
+ else:
+ bars = cached[-n:]
+
+ if fq:
+ bars = cls.qfq(bars)
+ return bars
+ except Exception as e:
+ logger.exception(e)
+ logger.warning(
+ "failed to get bars for %s, %s, %s, %s", code, n, frame_type, end
+ )
+ raise
+
+ @classmethod
+ async def _get_persisted_bars_in_range(
+ cls, code: str, frame_type: FrameType, start: Frame, end: Frame = None
+ ) -> BarsArray:
+ """从持久化数据库中获取介于[`start`, `end`]间的行情记录
+
+ 如果`start`到`end`区间某支股票停牌,则会返回空数组。
+
+ Args:
+ code: 证券代码
+ frame_type: 帧类型
+ start: 起始时间
+ end: 结束时间,如果未指明,则取当前时间
+
+ Returns:
+ 返回dtype为`coretypes.bars_dtype`的一维numpy数组。
+ """
+ end = end or datetime.datetime.now()
+
+ keep_cols = ["_time"] + list(bars_cols[1:])
+
+ measurement = cls._measurement_name(frame_type)
+ flux = (
+ Flux()
+ .bucket(cfg.influxdb.bucket_name)
+ .range(start, end)
+ .measurement(measurement)
+ .fields(keep_cols)
+ .tags({"code": code})
+ )
+
+ serializer = DataframeDeserializer(
+ encoding="utf-8",
+ names=[
+ "_",
+ "table",
+ "result",
+ "frame",
+ "code",
+ "amount",
+ "close",
+ "factor",
+ "high",
+ "low",
+ "open",
+ "volume",
+ ],
+ engine="c",
+ skiprows=0,
+ header=0,
+ usecols=bars_cols,
+ parse_dates=["frame"],
+ )
+
+ client = get_influx_client()
+ result = await client.query(flux, serializer)
+ return result.to_records(index=False).astype(bars_dtype)
+
+ @classmethod
+ async def _get_persisted_bars_n(
+ cls, code: str, frame_type: FrameType, n: int, end: Frame = None
+ ) -> BarsArray:
+ """从持久化数据库中获取截止到`end`的`n`条行情记录
+
+ 如果`end`未指定,则取当前时间。
+
+        基于influxdb查询的特性,在查询前,必须先根据`end`和`n`计算出起始时间,但如果在此期间某些股票有停牌,则返回的数据将少于`n`条。而如果起始时间设置得足够早,虽然能满足返回数据条数的要求,但会带来性能上的损失。因此,我们在计算起始时间时,不是直接使用`n`,而是使用`min(n * 2, n + 20)`来计算起始时间,这样多数情况下,能够保证返回数据的条数为`n`条。
+
+ 返回的数据按`frame`进行升序排列。
+
+ Args:
+ code: 证券代码
+ frame_type: 帧类型
+ n: 返回结果数量
+ end: 结束时间,如果未指明,则取当前时间
+
+ Returns:
+ 返回dtype为`bars_dtype`的numpy数组
+ """
+ # check is needed since tags accept List as well
+ assert isinstance(code, str), "`code` must be a string"
+
+ end = end or datetime.datetime.now()
+ closed_end = tf.floor(end, frame_type)
+ start = tf.shift(closed_end, -min(2 * n, n + 20), frame_type)
+
+ keep_cols = ["_time"] + list(bars_cols[1:])
+
+ measurement = cls._measurement_name(frame_type)
+ flux = (
+ Flux()
+ .bucket(cfg.influxdb.bucket_name)
+ .range(start, end)
+ .measurement(measurement)
+ .fields(keep_cols)
+ .tags({"code": code})
+ .latest(n)
+ )
+
+ serializer = DataframeDeserializer(
+ encoding="utf-8",
+ names=[
+ "_",
+ "table",
+ "result",
+ "frame",
+ "code",
+ "amount",
+ "close",
+ "factor",
+ "high",
+ "low",
+ "open",
+ "volume",
+ ],
+ engine="c",
+ skiprows=0,
+ header=0,
+ usecols=bars_cols,
+ parse_dates=["frame"],
+ )
+
+ client = get_influx_client()
+ result = await client.query(flux, serializer)
+ return result.to_records(index=False).astype(bars_dtype)
+
+ @classmethod
+ async def _batch_get_persisted_bars_n(
+ cls, codes: List[str], frame_type: FrameType, n: int, end: Frame = None
+ ) -> pd.DataFrame:
+ """从持久化存储中获取`codes`指定的一批股票截止`end`时的`n`条记录。
+
+ 返回的数据按`frame`进行升序排列。如果不存在满足指定条件的查询结果,将返回空的DataFrame。
+
+        基于influxdb查询的特性,在查询前,必须先根据`end`和`n`计算出起始时间,但如果在此期间某些股票有停牌,则返回的数据将少于`n`条。如果起始时间设置得足够早,虽然能满足返回数据条数的要求,但会带来性能上的损失。因此,我们在计算起始时间时,不是直接使用`n`,而是使用`min(n * 2, n + 20)`来计算起始时间,这样多数情况下,能够保证返回数据的条数为`n`条。
+
+ Args:
+ codes: 证券代码列表。
+ frame_type: 帧类型
+ n: 返回结果数量
+ end: 结束时间,如果未指定,则使用当前时间
+
+ Returns:
+ DataFrame, columns为`code`, `frame`, `open`, `high`, `low`, `close`, `volume`, `amount`, `factor`
+
+ """
+ max_query_size = min(cfg.influxdb.max_query_size, INFLUXDB_MAX_QUERY_SIZE)
+
+ if len(codes) * min(n + 20, 2 * n) > max_query_size:
+ raise BadParameterError(
+ f"codes的数量和n的乘积超过了influxdb的最大查询数量限制{max_query_size}"
+ )
+
+ end = end or datetime.datetime.now()
+ close_end = tf.floor(end, frame_type)
+ begin = tf.shift(close_end, -1 * min(n + 20, n * 2), frame_type)
+
+ # influxdb的查询结果格式类似于CSV,其列顺序为_, result_alias, table_seq, _time, tags, fields,其中tags和fields都是升序排列
+ keep_cols = ["code"] + list(bars_cols)
+ names = ["_", "result", "table", "frame", "code"]
+
+ # influxdb will return fields in the order of name ascending parallel
+ names.extend(sorted(bars_cols[1:]))
+
+ measurement = cls._measurement_name(frame_type)
+ flux = (
+ Flux()
+ .bucket(cfg.influxdb.bucket_name)
+ .range(begin, end)
+ .measurement(measurement)
+ .fields(keep_cols)
+ .latest(n)
+ )
+
+ if codes is not None:
+ assert isinstance(codes, list), "`codes` must be a list or None"
+ flux.tags({"code": codes})
+
+ deserializer = DataframeDeserializer(
+ names=names,
+ usecols=keep_cols,
+ encoding="utf-8",
+ time_col="frame",
+ engine="c",
+ )
+
+ client = get_influx_client()
+ return await client.query(flux, deserializer)
+
+ @classmethod
+ async def _batch_get_persisted_bars_in_range(
+ cls, codes: List[str], frame_type: FrameType, begin: Frame, end: Frame = None
+ ) -> pd.DataFrame:
+ """从持久化存储中获取`codes`指定的一批股票在`begin`和`end`之间的记录。
+
+ 返回的数据将按`frame`进行升序排列。
+ 注意,返回的数据有可能不是等长的,因为有的股票可能停牌。
+
+ Args:
+ codes: 证券代码列表。
+ frame_type: 帧类型
+ begin: 开始时间
+ end: 结束时间
+
+ Returns:
+ DataFrame, columns为`code`, `frame`, `open`, `high`, `low`, `close`, `volume`, `amount`, `factor`
+
+ """
+ end = end or datetime.datetime.now()
+
+ n = tf.count_frames(begin, end, frame_type)
+ max_query_size = min(cfg.influxdb.max_query_size, INFLUXDB_MAX_QUERY_SIZE)
+ if len(codes) * n > max_query_size:
+ raise BadParameterError(
+                f"requested {len(codes) * n} records, exceeding the max query size {max_query_size}"
+ )
+
+ # influxdb的查询结果格式类似于CSV,其列顺序为_, result_alias, table_seq, _time, tags, fields,其中tags和fields都是升序排列
+ keep_cols = ["code"] + list(bars_cols)
+ names = ["_", "result", "table", "frame", "code"]
+
+ # influxdb will return fields in the order of name ascending parallel
+ names.extend(sorted(bars_cols[1:]))
+
+ measurement = cls._measurement_name(frame_type)
+ flux = (
+ Flux()
+ .bucket(cfg.influxdb.bucket_name)
+ .range(begin, end)
+ .measurement(measurement)
+ .fields(keep_cols)
+ )
+
+ flux.tags({"code": codes})
+
+ deserializer = DataframeDeserializer(
+ names=names,
+ usecols=keep_cols,
+ encoding="utf-8",
+ time_col="frame",
+ engine="c",
+ )
+
+ client = get_influx_client()
+ df = await client.query(flux, deserializer)
+ return df
+
+ @classmethod
+ async def batch_cache_bars(cls, frame_type: FrameType, bars: Dict[str, BarsArray]):
+ """缓存已收盘的分钟线和日线
+
+ 当缓存日线时,仅限于当日收盘后的第一次同步时调用。
+
+ Args:
+ frame_type: 帧类型
+ bars: 行情数据,其key为股票代码,其value为dtype为`bars_dtype`的一维numpy数组。
+
+ Raises:
+ RedisError: 如果在执行过程中发生错误,则抛出以此异常为基类的各种异常,具体参考aioredis相关文档。
+ """
+ if frame_type == FrameType.DAY:
+ await cls.batch_cache_unclosed_bars(frame_type, bars)
+ return
+
+ pl = cache.security.pipeline()
+ for code, bars in bars.items():
+ key = f"bars:{frame_type.value}:{code}"
+ for bar in bars:
+ frame = tf.time2int(bar["frame"].item())
+ val = [*bar]
+ val[0] = frame
+ pl.hset(key, frame, ",".join(map(str, val)))
+ await pl.execute()
+
+ @classmethod
+ async def batch_cache_unclosed_bars(
+ cls, frame_type: FrameType, bars: Dict[str, BarsArray]
+ ): # pragma: no cover
+ """缓存未收盘的5、15、30、60分钟线及日线、周线、月线
+
+ Args:
+ frame_type: 帧类型
+ bars: 行情数据,其key为股票代码,其value为dtype为`bars_dtype`的一维numpy数组。bars不能为None,或者empty。
+
+ Raise:
+ RedisError: 如果在执行过程中发生错误,则抛出以此异常为基类的各种异常,具体参考aioredis相关文档。
+ """
+ pl = cache.security.pipeline()
+ key = f"bars:{frame_type.value}:unclosed"
+
+ convert = tf.time2int if frame_type in tf.minute_level_frames else tf.date2int
+
+ for code, bar in bars.items():
+ val = [*bar[0]]
+ val[0] = convert(bar["frame"][0].item()) # 时间转换
+ pl.hset(key, code, ",".join(map(str, val)))
+
+ await pl.execute()
+
+ @classmethod
+ async def reset_cache(cls):
+ """清除缓存的行情数据"""
+ try:
+ for ft in itertools.chain(tf.minute_level_frames, tf.day_level_frames):
+ keys = await cache.security.keys(f"bars:{ft.value}:*")
+ if keys:
+ await cache.security.delete(*keys)
+ finally:
+ cls._is_cache_empty = True
+
+ @classmethod
+ def _deserialize_cached_bars(cls, raw: List[str], ft: FrameType) -> BarsArray:
+ """从redis中反序列化缓存的数据
+
+ 如果`raw`空数组或者元素为`None`,则返回空数组。
+
+ Args:
+ raw: redis中的缓存数据
+ ft: 帧类型
+ sort: 是否需要重新排序,缺省为False
+
+ Returns:
+ BarsArray: 行情数据
+ """
+ fix_date = False
+ if ft in tf.minute_level_frames:
+ convert = tf.int2time
+ else:
+ convert = tf.int2date
+ fix_date = True
+ recs = []
+        # we could treat raw as CSV and parse it with pandas, but that is about 10x slower than this approach
+ for raw_rec in raw:
+ if raw_rec is None:
+ continue
+ f, o, h, l, c, v, m, fac = raw_rec.split(",")
+ if fix_date:
+ f = f[:8]
+ recs.append(
+ (
+ convert(f),
+ float(o),
+ float(h),
+ float(l),
+ float(c),
+ float(v),
+ float(m),
+ float(fac),
+ )
+ )
+
+ return np.array(recs, dtype=bars_dtype)
+
+ @classmethod
+ async def _batch_get_cached_bars_n(
+ cls, frame_type: FrameType, n: int, end: Frame = None, codes: List[str] = None
+ ) -> BarsPanel:
+ """批量获取在cache中截止`end`的`n`个bars。
+
+ 如果`end`不在`frame_type`所属的边界点上,那么,如果`end`大于等于当前缓存未收盘数据时间,则将包含未收盘数据;否则,返回的记录将截止到`tf.floor(end, frame_type)`。
+
+ Args:
+ frame_type: 时间帧类型
+ n: 返回记录条数
+ codes: 证券代码列表
+            end: 截止时间, 如果为None,则取当前时间
+
+ Returns:
+ BarsPanel: 行情数据
+ """
+ # 调用者自己保证end在缓存中
+ cols = list(bars_dtype_with_code.names)
+ if frame_type in tf.day_level_frames:
+ key = f"bars:{frame_type.value}:unclosed"
+ if codes is None:
+ recs = await cache.security.hgetall(key)
+ codes = list(recs.keys())
+ recs = recs.values()
+ else:
+ recs = await cache.security.hmget(key, *codes)
+
+ barss = cls._deserialize_cached_bars(recs, frame_type)
+ if barss.size > 0:
+ if len(barss) != len(codes):
+                    # issue 39, 如果某支股票当天停牌,则缓存中将不会有它的记录,此时需要移除其代码
+ codes = [
+ codes[i] for i, item in enumerate(recs) if item is not None
+ ]
+ barss = numpy_append_fields(barss, "code", codes, [("code", "O")])
+ return barss[cols].astype(bars_dtype_with_code)
+ else:
+ return np.array([], dtype=bars_dtype_with_code)
+ else:
+ end = end or datetime.datetime.now()
+ close_end = tf.floor(end, frame_type)
+ all_bars = []
+ if codes is None:
+ keys = await cache.security.keys(
+ f"bars:{frame_type.value}:*[^unclosed]"
+ )
+ codes = [key.split(":")[-1] for key in keys]
+ else:
+ keys = [f"bars:{frame_type.value}:{code}" for code in codes]
+
+ if frame_type != FrameType.MIN1:
+ unclosed = await cache.security.hgetall(
+ f"bars:{frame_type.value}:unclosed"
+ )
+ else:
+ unclosed = {}
+
+ pl = cache.security.pipeline()
+ frames = tf.get_frames_by_count(close_end, n, frame_type)
+ for key in keys:
+ pl.hmget(key, *frames)
+
+ all_closed = await pl.execute()
+ for code, raw in zip(codes, all_closed):
+ raw.append(unclosed.get(code))
+ barss = cls._deserialize_cached_bars(raw, frame_type)
+ barss = numpy_append_fields(
+ barss, "code", [code] * len(barss), [("code", "O")]
+ )
+ barss = barss[cols].astype(bars_dtype_with_code)
+ all_bars.append(barss[barss["frame"] <= end][-n:])
+
+ try:
+ return np.concatenate(all_bars)
+ except ValueError as e:
+ logger.exception(e)
+ return np.array([], dtype=bars_dtype_with_code)
+
+ @classmethod
+ async def _get_cached_bars_n(
+ cls, code: str, n: int, frame_type: FrameType, end: Frame = None
+ ) -> BarsArray:
+ """从缓存中获取指定代码的行情数据
+
+ 存取逻辑是,从`end`指定的时间向前取`n`条记录。`end`不应该大于当前系统时间,并且根据`end`和`n`计算出来的起始时间应该在缓存中存在。否则,两种情况下,返回记录数都将小于`n`。
+
+ 如果`end`不处于`frame_type`所属的边界结束位置,且小于当前已缓存的未收盘bar时间,则会返回前一个已收盘的数据,否则,返回的记录中还将包含未收盘的数据。
+
+ args:
+ code: 证券代码,比如000001.XSHE
+ n: 返回记录条数
+ frame_type: 帧类型
+ end: 结束帧,如果为None,则取当前时间
+
+ returns:
+ 元素类型为`coretypes.bars_dtype`的一维numpy数组。如果没有数据,则返回空ndarray。
+ """
+ # 50 times faster than arrow.now().floor('second')
+ end = end or datetime.datetime.now().replace(second=0, microsecond=0)
+
+ if frame_type in tf.minute_level_frames:
+ cache_start = tf.first_min_frame(end.date(), frame_type)
+ closed = tf.floor(end, frame_type)
+
+ frames = (tf.get_frames(cache_start, closed, frame_type))[-n:]
+ if len(frames) == 0:
+ recs = np.empty(shape=(0,), dtype=bars_dtype)
+ else:
+ key = f"bars:{frame_type.value}:{code}"
+ recs = await cache.security.hmget(key, *frames)
+ recs = cls._deserialize_cached_bars(recs, frame_type)
+
+ if closed < end:
+ # for unclosed
+ key = f"bars:{frame_type.value}:unclosed"
+ unclosed = await cache.security.hget(key, code)
+ unclosed = cls._deserialize_cached_bars([unclosed], frame_type)
+
+ if len(unclosed) == 0:
+ return recs[-n:]
+
+ if end < unclosed[0]["frame"].item():
+ # 如果unclosed为9:36, 调用者要求取9:29的5m数据,则取到的unclosed不合要求,抛弃。似乎没有更好的方法检测end与unclosed的关系
+ return recs[-n:]
+ else:
+ bars = np.concatenate((recs, unclosed))
+ return bars[-n:]
+ else:
+ return recs[-n:]
+ else: # 日线及以上级别,仅在缓存中存在未收盘数据
+ key = f"bars:{frame_type.value}:unclosed"
+ rec = await cache.security.hget(key, code)
+ return cls._deserialize_cached_bars([rec], frame_type)
+
+ @classmethod
+ async def cache_bars(cls, code: str, frame_type: FrameType, bars: BarsArray):
+ """将当期已收盘的行情数据缓存
+
+ Note:
+ 当前只缓存1分钟数据。其它分钟数据,都在调用时,通过resample临时合成。
+
+ 行情数据缓存在以`bars:{frame_type.value}:{code}`为key, {frame}为field的hashmap中。
+
+ Args:
+ code: the full qualified code of a security or index
+ frame_type: frame type of the bars
+ bars: the bars to cache, which is a numpy array of dtype `coretypes.bars_dtype`
+
+ Raises:
+ RedisError: if redis operation failed, see documentation of aioredis
+
+ """
+ # 转换时间为int
+ convert = tf.time2int if frame_type in tf.minute_level_frames else tf.date2int
+
+ key = f"bars:{frame_type.value}:{code}"
+ pl = cache.security.pipeline()
+ for bar in bars:
+ val = [*bar]
+ val[0] = convert(bar["frame"].item())
+ pl.hset(key, val[0], ",".join(map(str, val)))
+
+ await pl.execute()
+
+ @classmethod
+ async def cache_unclosed_bars(
+ cls, code: str, frame_type: FrameType, bars: BarsArray
+ ): # pragma: no cover
+ """将未结束的行情数据缓存
+
+ 未结束的行情数据缓存在以`bars:{frame_type.value}:unclosed`为key, {code}为field的hashmap中。
+
+ 尽管`bars`被声明为BarsArray,但实际上应该只包含一个元素。
+
+ Args:
+ code: the full qualified code of a security or index
+ frame_type: frame type of the bars
+ bars: the bars to cache, which is a numpy array of dtype `coretypes.bars_dtype`
+
+ Raises:
+ RedisError: if redis operation failed, see documentation of aioredis
+
+ """
+ converter = tf.time2int if frame_type in tf.minute_level_frames else tf.date2int
+
+ assert len(bars) == 1, "unclosed bars should only have one record"
+
+ key = f"bars:{frame_type.value}:unclosed"
+ bar = bars[0]
+ val = [*bar]
+ val[0] = converter(bar["frame"].item())
+ await cache.security.hset(key, code, ",".join(map(str, val)))
+
+ @classmethod
+ async def persist_bars(
+ cls,
+ frame_type: FrameType,
+ bars: Union[Dict[str, BarsArray], BarsArray, pd.DataFrame],
+ ):
+ """将行情数据持久化
+
+ 如果`bars`类型为Dict,则key为`code`,value为`bars`。如果其类型为BarsArray或者pd.DataFrame,则`bars`各列字段应该为`coretypes.bars_dtype` + ("code", "O")构成。
+
+ Args:
+ frame_type: the frame type of the bars
+ bars: the bars to be persisted
+
+ Raises:
+ InfluxDBWriteError: if influxdb write failed
+ """
+ client = get_influx_client()
+
+ measurement = cls._measurement_name(frame_type)
+ logger.info("persisting bars to influxdb: %s, %d secs", measurement, len(bars))
+
+ if isinstance(bars, dict):
+ for code, value in bars.items():
+ await client.save(
+ value, measurement, global_tags={"code": code}, time_key="frame"
+ )
+ else:
+ await client.save(bars, measurement, tag_keys=["code"], time_key="frame")
+
+ @classmethod
+ def resample(
+ cls, bars: BarsArray, from_frame: FrameType, to_frame: FrameType
+ ) -> BarsArray:
+ """将原来为`from_frame`的行情数据转换为`to_frame`的行情数据
+
+ 如果`to_frame`为日线或者分钟级别线,则`from_frame`必须为分钟线;如果`to_frame`为周以上级别线,则`from_frame`必须为日线。其它级别之间的转换不支持。
+
+ 如果`from_frame`为1分钟线,则必须从9:31起。
+
+ Args:
+ bars (BarsArray): 行情数据
+ from_frame (FrameType): 转换前的FrameType
+ to_frame (FrameType): 转换后的FrameType
+
+ Returns:
+ BarsArray: 转换后的行情数据
+ """
+ if from_frame == FrameType.MIN1:
+ return cls._resample_from_min1(bars, to_frame)
+ elif from_frame == FrameType.DAY: # pragma: no cover
+ return cls._resample_from_day(bars, to_frame)
+ else: # pragma: no cover
+ raise TypeError(f"unsupported from_frame: {from_frame}")
+
+ @classmethod
+ def _measurement_name(cls, frame_type):
+ return f"stock_bars_{frame_type.value}"
+
+ @classmethod
+ def _resample_from_min1(cls, bars: BarsArray, to_frame: FrameType) -> BarsArray:
+ """将`bars`从1分钟线转换为`to_frame`的行情数据
+
+ 重采样后的数据只包含frame, open, high, low, close, volume, amount, factor,无论传入数据是否还有别的字段,它们都将被丢弃。
+
+ resampling 240根分钟线到5分钟大约需要100微秒。
+
+ TODO: 如果`bars`中包含nan怎么处理?
+ """
+ if bars[0]["frame"].item().minute != 31:
+ raise ValueError("resampling from 1min must start from 9:31")
+
+ if to_frame not in (
+ FrameType.MIN5,
+ FrameType.MIN15,
+ FrameType.MIN30,
+ FrameType.MIN60,
+ FrameType.DAY,
+ ):
+ raise ValueError(f"unsupported to_frame: {to_frame}")
+
+ bins_len = {
+ FrameType.MIN5: 5,
+ FrameType.MIN15: 15,
+ FrameType.MIN30: 30,
+ FrameType.MIN60: 60,
+ FrameType.DAY: 240,
+ }[to_frame]
+
+ bins = len(bars) // bins_len
+ npart1 = bins * bins_len
+
+ part1 = bars[:npart1].reshape((-1, bins_len))
+ part2 = bars[npart1:]
+
+ open_pos = np.arange(bins) * bins_len
+ close_pos = np.arange(1, bins + 1) * bins_len - 1
+ if len(bars) > bins_len * bins:
+ close_pos = np.append(close_pos, len(bars) - 1)
+ resampled = np.empty((bins + 1,), dtype=bars_dtype)
+ else:
+ resampled = np.empty((bins,), dtype=bars_dtype)
+
+ resampled[:bins]["open"] = bars[open_pos]["open"]
+
+ resampled[:bins]["high"] = np.max(part1["high"], axis=1)
+ resampled[:bins]["low"] = np.min(part1["low"], axis=1)
+
+ resampled[:bins]["volume"] = np.sum(part1["volume"], axis=1)
+ resampled[:bins]["amount"] = np.sum(part1["amount"], axis=1)
+
+ if len(part2):
+ resampled[-1]["open"] = part2["open"][0]
+ resampled[-1]["high"] = np.max(part2["high"])
+ resampled[-1]["low"] = np.min(part2["low"])
+
+ resampled[-1]["volume"] = np.sum(part2["volume"])
+ resampled[-1]["amount"] = np.sum(part2["amount"])
+
+ cols = ["frame", "close", "factor"]
+ resampled[cols] = bars[close_pos][cols]
+
+ if to_frame == FrameType.DAY:
+ resampled["frame"] = bars[-1]["frame"].item().date()
+
+ return resampled
+
+ @classmethod
+ def _resample_from_day(cls, bars: BarsArray, to_frame: FrameType) -> BarsArray:
+ """将`bars`从日线转换成`to_frame`的行情数据
+
+ Args:
+ bars (BarsArray): [description]
+ to_frame (FrameType): [description]
+
+ Returns:
+ 转换后的行情数据
+ """
+ rules = {
+ "frame": "last",
+ "open": "first",
+ "high": "max",
+ "low": "min",
+ "close": "last",
+ "volume": "sum",
+ "amount": "sum",
+ "factor": "last",
+ }
+
+ if to_frame == FrameType.WEEK:
+ freq = "W-Fri"
+ elif to_frame == FrameType.MONTH:
+ freq = "M"
+ elif to_frame == FrameType.QUARTER:
+ freq = "Q"
+ elif to_frame == FrameType.YEAR:
+ freq = "A"
+ else:
+ raise ValueError(f"unsupported to_frame: {to_frame}")
+
+ df = pd.DataFrame(bars)
+ df.index = pd.to_datetime(bars["frame"])
+ df = df.resample(freq).agg(rules)
+ bars = np.array(df.to_records(index=False), dtype=bars_dtype)
+
+ # filter out data like (None, nan, ...)
+ return bars[np.isfinite(bars["close"])]
+
+ @classmethod
+ async def _get_price_limit_in_cache(
+ cls, code: str, begin: datetime.date, end: datetime.date
+ ):
+ date_str = await cache._security_.get(TRADE_PRICE_LIMITS_DATE)
+ if date_str:
+ date_in_cache = arrow.get(date_str).date()
+ if date_in_cache < begin or date_in_cache > end:
+ return None
+ else:
+ return None
+
+ dtype = [("frame", "O"), ("high_limit", "f4"), ("low_limit", "f4")]
+ hp = await cache._security_.hget(TRADE_PRICE_LIMITS, f"{code}.high_limit")
+ lp = await cache._security_.hget(TRADE_PRICE_LIMITS, f"{code}.low_limit")
+ if hp is None or lp is None:
+ return None
+ else:
+ return np.array([(date_in_cache, hp, lp)], dtype=dtype)
+
+ @classmethod
+ async def get_trade_price_limits(
+ cls, code: str, begin: Frame, end: Frame
+ ) -> BarsArray:
+ """从influxdb和cache中获取个股在[begin, end]之间的涨跌停价。
+
+ 涨跌停价只有日线数据才有,因此,FrameType固定为FrameType.DAY,
+ 当天的数据存放于redis,如果查询日期包含当天(交易日),从cache中读取并追加到结果中
+
+ Args:
+ code : 个股代码
+ begin : 开始日期
+ end : 结束日期
+
+ Returns:
+ dtype为[('frame', 'O'), ('high_limit', 'f4'), ('low_limit', 'f4')]的numpy数组
+ """
+ cols = ["_time", "high_limit", "low_limit"]
+ dtype = [("frame", "O"), ("high_limit", "f4"), ("low_limit", "f4")]
+
+ if isinstance(begin, datetime.datetime):
+ begin = begin.date() # 强制转换为date
+ if isinstance(end, datetime.datetime):
+ end = end.date() # 强制转换为date
+
+ data_in_cache = await cls._get_price_limit_in_cache(code, begin, end)
+
+ client = get_influx_client()
+ measurement = cls._measurement_name(FrameType.DAY)
+ flux = (
+ Flux()
+ .bucket(client._bucket)
+ .measurement(measurement)
+ .range(begin, end)
+ .tags({"code": code})
+ .fields(cols)
+ .sort("_time")
+ )
+
+ ds = NumpyDeserializer(
+ dtype,
+ use_cols=cols,
+ converters={"_time": lambda x: ciso8601.parse_datetime(x).date()},
+ # since we ask parse date in convertors, so we have to disable parse_date
+ parse_date=None,
+ )
+
+ result = await client.query(flux, ds)
+ if data_in_cache:
+ result = np.concatenate([result, data_in_cache])
+ return result
+
+ @classmethod
+ async def reset_price_limits_cache(cls, cache_only: bool, dt: datetime.date = None):
+ if cache_only is False:
+ date_str = await cache._security_.get(TRADE_PRICE_LIMITS_DATE)
+ if not date_str:
+ return # skip clear action if date not found in cache
+ date_in_cache = arrow.get(date_str).date()
+            if dt is None or date_in_cache != dt:  # 仅当dt与cache中的日期相同时,才执行清除
+ return # skip clear action
+
+ await cache._security_.delete(TRADE_PRICE_LIMITS)
+ await cache._security_.delete(TRADE_PRICE_LIMITS_DATE)
+
+ @classmethod
+ async def save_trade_price_limits(
+ cls, price_limits: LimitPriceOnlyBarsArray, to_cache: bool
+ ):
+ """保存涨跌停价
+
+ Args:
+ price_limits: 要保存的涨跌停价格数据。
+ to_cache: 是保存到缓存中,还是保存到持久化存储中
+ """
+ if len(price_limits) == 0:
+ return
+
+ if to_cache: # 每个交易日上午9点更新两次
+ pl = cache._security_.pipeline()
+ for row in price_limits:
+ # .item convert np.float64 to python float
+ pl.hset(
+ TRADE_PRICE_LIMITS,
+ f"{row['code']}.high_limit",
+ row["high_limit"].item(),
+ )
+ pl.hset(
+ TRADE_PRICE_LIMITS,
+ f"{row['code']}.low_limit",
+ row["low_limit"].item(),
+ )
+
+ dt = price_limits[-1]["frame"]
+ pl.set(TRADE_PRICE_LIMITS_DATE, dt.strftime("%Y-%m-%d"))
+ await pl.execute()
+ else:
+ # to influxdb, 每个交易日的第二天早上2点保存
+ client = get_influx_client()
+ await client.save(
+ price_limits,
+ cls._measurement_name(FrameType.DAY),
+ tag_keys="code",
+ time_key="frame",
+ )
+
+ @classmethod
+ async def trade_price_limit_flags(
+ cls, code: str, start: datetime.date, end: datetime.date
+ ) -> Tuple[List[bool]]:
+ """获取个股在[start, end]之间的涨跌停标志
+
+ !!!Note
+ 本函数返回的序列在股票有停牌的情况下,将不能与[start, end]一一对应。
+
+ Args:
+ code: 个股代码
+ start: 开始日期
+ end: 结束日期
+
+ Returns:
+ 涨跌停标志列表(buy, sell)
+ """
+ cols = ["_time", "close", "high_limit", "low_limit"]
+ client = get_influx_client()
+ measurement = cls._measurement_name(FrameType.DAY)
+ flux = (
+ Flux()
+ .bucket(client._bucket)
+ .measurement(measurement)
+ .range(start, end)
+ .tags({"code": code})
+ .fields(cols)
+ .sort("_time")
+ )
+
+ dtype = [
+ ("frame", "O"),
+ ("close", "f4"),
+ ("high_limit", "f4"),
+ ("low_limit", "f4"),
+ ]
+ ds = NumpyDeserializer(
+ dtype,
+ use_cols=["_time", "close", "high_limit", "low_limit"],
+ converters={"_time": lambda x: ciso8601.parse_datetime(x).date()},
+ # since we ask parse date in convertors, so we have to disable parse_date
+ parse_date=None,
+ )
+
+ result = await client.query(flux, ds)
+ if result.size == 0:
+ return np.array([], dtype=dtype)
+
+ return (
+ array_price_equal(result["close"], result["high_limit"]),
+ array_price_equal(result["close"], result["low_limit"]),
+ )
+
+ @classmethod
+ async def trade_price_limit_flags_ex(
+ cls, code: str, start: datetime.date, end: datetime.date
+ ) -> Dict[datetime.date, Tuple[bool, bool]]:
+ """获取股票`code`在`[start, end]`区间的涨跌停标志
+
+ !!!Note
+ 如果end为当天,注意在未收盘之前,这个涨跌停标志都是不稳定的
+
+ Args:
+ code: 股票代码
+ start: 起始日期
+ end: 结束日期
+
+ Returns:
+ 以日期为key,(涨停,跌停)为值的dict
+ """
+ limit_prices = await cls.get_trade_price_limits(code, start, end)
+ bars = await Stock.get_bars_in_range(
+ code, FrameType.DAY, start=start, end=end, fq=False
+ )
+
+ close = bars["close"]
+
+ results = {}
+
+ # aligned = True
+ for i in range(len(bars)):
+ if bars[i]["frame"].item().date() != limit_prices[i]["frame"]:
+ # aligned = False
+ logger.warning("数据同步错误,涨跌停价格与收盘价时间不一致: %s, %s", code, bars[i]["frame"])
+ break
+
+ results[limit_prices[i]["frame"]] = (
+ price_equal(limit_prices[i]["high_limit"], close[i]),
+ price_equal(limit_prices[i]["low_limit"], close[i]),
+ )
+
+ # if not aligned:
+ # bars = bars[i:]
+ # limit_prices = limit_prices[i:]
+
+ # for frame in bars["frame"]:
+ # frame = frame.item().date()
+ # close = bars[bars["frame"].item().date() == frame]["close"].item()
+ # high = limit_prices[limit_prices["frame"] == frame]["high_limit"].item()
+ # low = limit_prices[limit_prices["frame"] == frame]["low_limit"].item()
+ # results[frame] = (
+ # price_equal(high, close),
+ # price_equal(low, close)
+ # )
+
+ return results
+
+ @classmethod
+ async def get_latest_price(cls, codes: Iterable[str]) -> List[str]:
+ """获取多支股票的最新价格(交易日当天),暂不包括指数
+
+ 价格数据每5秒更新一次,接受多只股票查询,返回最后缓存的价格
+
+ Args:
+ codes: 代码列表
+
+ Returns:
+ 返回一个List,价格是字符形式的浮点数。
+ """
+ if not codes:
+ return []
+
+ _raw_code_list = []
+ for code_str in codes:
+ code, _ = code_str.split(".")
+ _raw_code_list.append(code)
+
+ _converted_data = []
+ raw_data = await cache.feature.hmget(TRADE_LATEST_PRICE, *_raw_code_list)
+ for _data in raw_data:
+ if _data is None:
+ _converted_data.append(_data)
+ else:
+ _converted_data.append(float(_data))
+ return _converted_data
+
security_type: SecurityType
+
+
+ property
+ readonly
+
+
+¶Returns the security type.
+
+Returns:
+
+| Type | Description |
+|---|---|
+| `SecurityType` | the type of this security |
+
batch_cache_bars(frame_type, bars)
+
+
+ async
+ classmethod
+
+
+¶Cache closed minute-level and day-level bars.
+When caching day-level bars, this should only be called on the first sync after the market closes for the day.
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `frame_type` | `FrameType` | frame type | required |
+| `bars` | `Dict[str, BarsArray]` | bars to cache; keys are stock codes, values are 1-D numpy arrays of dtype `bars_dtype` | required |
+
+Exceptions:
+
+| Type | Description |
+|---|---|
+| `RedisError` | if an error occurs during execution, exceptions derived from this base class are raised; see the aioredis documentation |
+
omicron/models/stock.py
@classmethod
+async def batch_cache_bars(cls, frame_type: FrameType, bars: Dict[str, BarsArray]):
+ """缓存已收盘的分钟线和日线
+
+ 当缓存日线时,仅限于当日收盘后的第一次同步时调用。
+
+ Args:
+ frame_type: 帧类型
+ bars: 行情数据,其key为股票代码,其value为dtype为`bars_dtype`的一维numpy数组。
+
+ Raises:
+ RedisError: 如果在执行过程中发生错误,则抛出以此异常为基类的各种异常,具体参考aioredis相关文档。
+ """
+ if frame_type == FrameType.DAY:
+ await cls.batch_cache_unclosed_bars(frame_type, bars)
+ return
+
+ pl = cache.security.pipeline()
+ for code, bars in bars.items():
+ key = f"bars:{frame_type.value}:{code}"
+ for bar in bars:
+ frame = tf.time2int(bar["frame"].item())
+ val = [*bar]
+ val[0] = frame
+ pl.hset(key, frame, ",".join(map(str, val)))
+ await pl.execute()
+
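+To make the cached layout concrete, here is a small illustrative sketch (values are made up; the int conversion mimics what `tf.time2int` produces, so the demo does not depend on omicron internals):
+
+```python
+import numpy as np
+from coretypes import bars_dtype  # frame, open, high, low, close, volume, amount, factor
+
+bar = np.array(
+    [("2023-01-03T10:00:00", 10.0, 10.2, 9.9, 10.1, 1e6, 1.01e7, 1.2)],
+    dtype=bars_dtype,
+)[0]
+
+frame = bar["frame"].item()                    # datetime.datetime(2023, 1, 3, 10, 0)
+frame_int = int(frame.strftime("%Y%m%d%H%M"))  # 202301031000, like tf.time2int
+
+key = "bars:30m:000001.XSHE"                   # bars:{frame_type.value}:{code}
+val = [frame_int] + [bar[f].item() for f in bars_dtype.names[1:]]
+# the hash field is the int frame; the value is the CSV of all fields
+print(key, frame_int, ",".join(map(str, val)))
+```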
batch_cache_unclosed_bars(frame_type, bars)
+
+
+ async
+ classmethod
+
+
+¶Cache unclosed 5/15/30/60-minute bars as well as unclosed day, week, and month bars.
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `frame_type` | `FrameType` | frame type | required |
+| `bars` | `Dict[str, BarsArray]` | bars to cache; keys are stock codes, values are 1-D numpy arrays of dtype `bars_dtype`; must not be None or empty | required |
+
+Exceptions:
+
+| Type | Description |
+|---|---|
+| `RedisError` | if an error occurs during execution, exceptions derived from this base class are raised; see the aioredis documentation |
+
omicron/models/stock.py
@classmethod
+async def batch_cache_unclosed_bars(
+ cls, frame_type: FrameType, bars: Dict[str, BarsArray]
+): # pragma: no cover
+ """缓存未收盘的5、15、30、60分钟线及日线、周线、月线
+
+ Args:
+ frame_type: 帧类型
+ bars: 行情数据,其key为股票代码,其value为dtype为`bars_dtype`的一维numpy数组。bars不能为None,或者empty。
+
+ Raise:
+ RedisError: 如果在执行过程中发生错误,则抛出以此异常为基类的各种异常,具体参考aioredis相关文档。
+ """
+ pl = cache.security.pipeline()
+ key = f"bars:{frame_type.value}:unclosed"
+
+ convert = tf.time2int if frame_type in tf.minute_level_frames else tf.date2int
+
+ for code, bar in bars.items():
+ val = [*bar[0]]
+ val[0] = convert(bar["frame"][0].item()) # 时间转换
+ pl.hset(key, code, ",".join(map(str, val)))
+
+ await pl.execute()
+
batch_get_day_level_bars_in_range(codes, frame_type, start, end, fq=True)
+
+
+ classmethod
+
+
+¶Get bars for multiple stocks (or indexes) within the time range [start, end).
+Use this interface when the bars you need are day-level or above (i.e. 1d, 1w, 1M).
+For how suspension periods are handled, see get_bars.
+This function returns an async iterator; usage example:
+```
+async for code, bars in Stock.batch_get_day_level_bars_in_range(...):
+    print(code, bars)
+```
+If `end` is not on a frame boundary of `frame_type`, then: if `end` is greater than or equal to the timestamp of the currently cached unclosed bar, the unclosed bar is included; otherwise, the returned records stop at `tf.floor(end, frame_type)`.
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `codes` | `List[str]` | list of codes | required |
+| `frame_type` | `FrameType` | frame type | required |
+| `start` | `Union[datetime.date, datetime.datetime]` | start time | required |
+| `end` | `Union[datetime.date, datetime.datetime]` | end time | required |
+| `fq` | `bool` | whether to adjust prices; if True, forward adjustment is applied. Defaults to True. | `True` |
+
+Returns:
+
+| Type | Description |
+|---|---|
+| `Generator[Dict[str, BarsArray], None, None]` | an async iterator yielding (code, bars) pairs |
+
omicron/models/stock.py
@classmethod
+async def batch_get_day_level_bars_in_range(
+ cls,
+ codes: List[str],
+ frame_type: FrameType,
+ start: Frame,
+ end: Frame,
+ fq: bool = True,
+) -> Generator[Dict[str, BarsArray], None, None]:
+ """获取多支股票(指数)在[start, end)时间段内的行情数据
+
+ 如果要获取的行情数据是日线级别(即1d, 1w, 1M),使用本接口。
+
+ 停牌数据处理请见[get_bars][omicron.models.stock.Stock.get_bars]。
+
+ 本函数返回一个迭代器,使用方法示例:
+ ```
+ async for code, bars in Stock.batch_get_day_level_bars_in_range(...):
+ print(code, bars)
+ ```
+
+ 如果`end`不在`frame_type`所属的边界点上,那么,如果`end`大于等于当前缓存未收盘数据时间,则将包含未收盘数据;否则,返回的记录将截止到`tf.floor(end, frame_type)`。
+
+ Args:
+ codes: 代码列表
+ frame_type: 帧类型
+ start: 起始时间
+ end: 结束时间
+ fq: 是否进行复权,如果是,则进行前复权。Defaults to True.
+
+ Returns:
+ Generator[Dict[str, BarsArray], None, None]: 迭代器,每次返回一个字典,其中key为代码,value为行情数据
+ """
+ today = datetime.datetime.now().date()
+    # 日线:仅当end等于最后一个交易日边界(tf.floor(today))时,当日数据仍在缓存中
+ if frame_type == FrameType.DAY and end == tf.floor(today, frame_type):
+ from_cache = True
+ elif frame_type != FrameType.DAY and start > tf.floor(today, frame_type):
+ from_cache = True
+ else:
+ from_cache = False
+
+ n = tf.count_frames(start, end, frame_type)
+ max_query_size = min(cfg.influxdb.max_query_size, INFLUXDB_MAX_QUERY_SIZE)
+ batch_size = max(max_query_size // n, 1)
+
+ for i in range(0, len(codes), batch_size):
+ batch_codes = codes[i : i + batch_size]
+ persisted = await cls._batch_get_persisted_bars_in_range(
+ batch_codes, frame_type, start, end
+ )
+
+ if from_cache:
+ cached = await cls._batch_get_cached_bars_n(
+ frame_type, 1, end, batch_codes
+ )
+ cached = pd.DataFrame(cached, columns=bars_dtype_with_code.names)
+
+ df = pd.concat([persisted, cached])
+ else:
+ df = persisted
+
+ for code in batch_codes:
+ filtered = df[df["code"] == code][bars_cols]
+ bars = filtered.to_records(index=False).astype(bars_dtype)
+ if fq:
+ bars = cls.qfq(bars)
+
+ yield code, bars
+
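+A hedged end-to-end usage sketch (it assumes omicron has been initialized against reachable Redis/InfluxDB instances; codes and dates are illustrative):
+
+```python
+import asyncio
+import datetime
+
+import omicron
+from coretypes import FrameType
+from omicron.models.stock import Stock
+
+async def main():
+    await omicron.init()  # loads calendar/security list, connects the cache
+
+    codes = ["000001.XSHE", "600000.XSHG"]
+    start = datetime.date(2023, 1, 3)
+    end = datetime.date(2023, 1, 31)
+
+    async for code, bars in Stock.batch_get_day_level_bars_in_range(
+        codes, FrameType.DAY, start, end
+    ):
+        # bars: numpy array of bars_dtype, forward-adjusted by default
+        print(code, len(bars), bars["close"][-1] if len(bars) else None)
+
+asyncio.run(main())
+```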
batch_get_min_level_bars_in_range(codes, frame_type, start, end, fq=True)
+
+
+ classmethod
+
+
+¶Get bars for multiple stocks (or indexes) within the time range [start, end).
+Use this interface when the bars you need are minute-level (i.e. 1m, 5m, 15m, 30m and 60m).
+For how suspension periods are handled, see get_bars.
+This function returns an async iterator; usage example:
+```
+async for code, bars in Stock.batch_get_min_level_bars_in_range(...):
+    print(code, bars)
+```
+If `end` is not on a frame boundary of `frame_type`, then: if `end` is greater than or equal to the timestamp of the currently cached unclosed bar, the unclosed bar is included; otherwise, the returned records stop at `tf.floor(end, frame_type)`.
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `codes` | `List[str]` | list of stock/index codes | required |
+| `frame_type` | `FrameType` | frame type | required |
+| `start` | `Union[datetime.date, datetime.datetime]` | start time | required |
+| `end` | `Union[datetime.date, datetime.datetime]` | end time; if not given, the current time is used | required |
+| `fq` | `bool` | whether to adjust prices; if True, forward adjustment is applied. Defaults to True. | `True` |
+
+Returns:
+
+| Type | Description |
+|---|---|
+| `Generator[Dict[str, BarsArray], None, None]` | an async iterator yielding (code, bars) pairs |
+
omicron/models/stock.py
@classmethod
+async def batch_get_min_level_bars_in_range(
+ cls,
+ codes: List[str],
+ frame_type: FrameType,
+ start: Frame,
+ end: Frame,
+ fq: bool = True,
+) -> Generator[Dict[str, BarsArray], None, None]:
+ """获取多支股票(指数)在[start, end)时间段内的行情数据
+
+ 如果要获取的行情数据是分钟级别(即1m, 5m, 15m, 30m和60m),使用本接口。
+
+ 停牌数据处理请见[get_bars][omicron.models.stock.Stock.get_bars]。
+
+ 本函数返回一个迭代器,使用方法示例:
+ ```
+ async for code, bars in Stock.batch_get_min_level_bars_in_range(...):
+ print(code, bars)
+ ```
+
+ 如果`end`不在`frame_type`所属的边界点上,那么,如果`end`大于等于当前缓存未收盘数据时间,则将包含未收盘数据;否则,返回的记录将截止到`tf.floor(end, frame_type)`。
+
+ Args:
+ codes: 股票/指数代码列表
+ frame_type: 帧类型
+ start: 起始时间
+ end: 结束时间。如果未指明,则取当前时间。
+ fq: 是否进行复权,如果是,则进行前复权。Defaults to True.
+
+ Returns:
+ Generator[Dict[str, BarsArray], None, None]: 迭代器,每次返回一个字典,其中key为代码,value为行情数据
+ """
+ closed_end = tf.floor(end, frame_type)
+ n = tf.count_frames(start, closed_end, frame_type)
+ max_query_size = min(cfg.influxdb.max_query_size, INFLUXDB_MAX_QUERY_SIZE)
+ batch_size = max(1, max_query_size // n)
+ ff = tf.first_min_frame(datetime.datetime.now(), frame_type)
+
+ for i in range(0, len(codes), batch_size):
+ batch_codes = codes[i : i + batch_size]
+
+ if end < ff:
+ part1 = await cls._batch_get_persisted_bars_in_range(
+ batch_codes, frame_type, start, end
+ )
+ part2 = pd.DataFrame([], columns=bars_dtype_with_code.names)
+ elif start >= ff:
+ part1 = pd.DataFrame([], columns=bars_dtype_with_code.names)
+ n = tf.count_frames(start, closed_end, frame_type) + 1
+ cached = await cls._batch_get_cached_bars_n(
+ frame_type, n, end, batch_codes
+ )
+ cached = cached[cached["frame"] >= start]
+ part2 = pd.DataFrame(cached, columns=bars_dtype_with_code.names)
+ else:
+ part1 = await cls._batch_get_persisted_bars_in_range(
+ batch_codes, frame_type, start, ff
+ )
+ n = tf.count_frames(start, closed_end, frame_type) + 1
+ cached = await cls._batch_get_cached_bars_n(
+ frame_type, n, end, batch_codes
+ )
+ part2 = pd.DataFrame(cached, columns=bars_dtype_with_code.names)
+
+ df = pd.concat([part1, part2])
+
+ for code in batch_codes:
+ filtered = df[df["code"] == code][bars_cols]
+ bars = filtered.to_records(index=False).astype(bars_dtype)
+ if fq:
+ bars = cls.qfq(bars)
+
+ yield code, bars
+
cache_bars(code, frame_type, bars)
+
+
+ async
+ classmethod
+
+
+¶Cache bars that have already closed in the current period.
+Note
+Only 1-minute bars are currently cached; bars of other minute-level frames are synthesized on the fly via resample when requested.
+Bars are cached in a hashmap keyed by `bars:{frame_type.value}:{code}`, with {frame} as the field.
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `code` | `str` | the fully qualified code of a security or index | required |
+| `frame_type` | `FrameType` | frame type of the bars | required |
+| `bars` | `BarsArray` | the bars to cache, a numpy array of dtype `coretypes.bars_dtype` | required |
+
+Exceptions:
+
+| Type | Description |
+|---|---|
+| `RedisError` | if a redis operation failed; see the aioredis documentation |
+
omicron/models/stock.py
@classmethod
+async def cache_bars(cls, code: str, frame_type: FrameType, bars: BarsArray):
+ """将当期已收盘的行情数据缓存
+
+ Note:
+ 当前只缓存1分钟数据。其它分钟数据,都在调用时,通过resample临时合成。
+
+ 行情数据缓存在以`bars:{frame_type.value}:{code}`为key, {frame}为field的hashmap中。
+
+ Args:
+ code: the full qualified code of a security or index
+ frame_type: frame type of the bars
+ bars: the bars to cache, which is a numpy array of dtype `coretypes.bars_dtype`
+
+ Raises:
+ RedisError: if redis operation failed, see documentation of aioredis
+
+ """
+ # 转换时间为int
+ convert = tf.time2int if frame_type in tf.minute_level_frames else tf.date2int
+
+ key = f"bars:{frame_type.value}:{code}"
+ pl = cache.security.pipeline()
+ for bar in bars:
+ val = [*bar]
+ val[0] = convert(bar["frame"].item())
+ pl.hset(key, val[0], ",".join(map(str, val)))
+
+ await pl.execute()
+
cache_unclosed_bars(code, frame_type, bars)
+
+
+ async
+ classmethod
+
+
+¶Cache unclosed bars.
+Unclosed bars are cached in a hashmap keyed by `bars:{frame_type.value}:unclosed`, with {code} as the field.
+Although `bars` is declared as BarsArray, it should contain exactly one element.
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `code` | `str` | the fully qualified code of a security or index | required |
+| `frame_type` | `FrameType` | frame type of the bars | required |
+| `bars` | `BarsArray` | the bars to cache, a numpy array of dtype `coretypes.bars_dtype` | required |
+
+Exceptions:
+
+| Type | Description |
+|---|---|
+| `RedisError` | if a redis operation failed; see the aioredis documentation |
+
omicron/models/stock.py
@classmethod
+async def cache_unclosed_bars(
+ cls, code: str, frame_type: FrameType, bars: BarsArray
+): # pragma: no cover
+ """将未结束的行情数据缓存
+
+ 未结束的行情数据缓存在以`bars:{frame_type.value}:unclosed`为key, {code}为field的hashmap中。
+
+ 尽管`bars`被声明为BarsArray,但实际上应该只包含一个元素。
+
+ Args:
+ code: the full qualified code of a security or index
+ frame_type: frame type of the bars
+ bars: the bars to cache, which is a numpy array of dtype `coretypes.bars_dtype`
+
+ Raises:
+ RedisError: if redis operation failed, see documentation of aioredis
+
+ """
+ converter = tf.time2int if frame_type in tf.minute_level_frames else tf.date2int
+
+ assert len(bars) == 1, "unclosed bars should only have one record"
+
+ key = f"bars:{frame_type.value}:unclosed"
+ bar = bars[0]
+ val = [*bar]
+ val[0] = converter(bar["frame"].item())
+ await cache.security.hset(key, code, ",".join(map(str, val)))
+
days_since_ipo(self)
+
+
+¶Get the number of trading days since the security's IPO.
+Because the trading calendar is limited (no calendar exists before 2005-01-04), securities listed before that date are counted from 2005-01-04.
+
+Returns:
+
+| Type | Description |
+|---|---|
+| `int` | the number of trading days since IPO (or since 2005-01-04) |
+
omicron/models/stock.py
def days_since_ipo(self) -> int:
+ """获取上市以来经过了多少个交易日
+
+ 由于受交易日历限制(2005年1月4日之前的交易日历没有),对于在之前上市的品种,都返回从2005年1月4日起的日期。
+
+ Returns:
+ int: [description]
+ """
+ epoch_start = arrow.get("2005-01-04").date()
+ ipo_day = self.ipo_date if self.ipo_date > epoch_start else epoch_start
+ return tf.count_day_frames(ipo_day, arrow.now().date())
+
format_code(code)
+
+
+ staticmethod
+
+
+¶NEEQ (New Third Board) and Beijing Stock Exchange codes are not supported yet; None is returned for them.
+Shanghai A-shares: 600, 601, 603, 605
+Shenzhen A-shares: 000, 001
+SME board: 002, 003
+ChiNext: 300, 301
+STAR Market: 688
+NEEQ: 82, 83, 87, 88, 430, 420, 400
+BSE: 43, 83, 87, 88
+ +omicron/models/stock.py
@staticmethod
+def format_code(code) -> str:
+ """新三板和北交所的股票, 暂不支持, 默认返回None
+ 上证A股: 600、601、603、605
+ 深证A股: 000、001
+ 中小板: 002、003
+ 创业板: 300/301
+ 科创板: 688
+ 新三板: 82、83、87、88、430、420、400
+ 北交所: 43、83、87、88
+ """
+ if not code or len(code) != 6:
+ return None
+
+ prefix = code[0]
+ if prefix in ("0", "3"):
+ return f"{code}.XSHE"
+ elif prefix == "6":
+ return f"{code}.XSHG"
+ else:
+ return None
+
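+Examples of the prefix rules above (the exchange suffix is derived purely from the code prefix):
+
+```python
+from omicron.models.stock import Stock
+
+assert Stock.format_code("600000") == "600000.XSHG"  # Shanghai A-share
+assert Stock.format_code("000001") == "000001.XSHE"  # Shenzhen A-share
+assert Stock.format_code("300750") == "300750.XSHE"  # ChiNext
+assert Stock.format_code("830799") is None           # NEEQ/BSE: unsupported
+assert Stock.format_code("60000") is None            # not a 6-digit code
+```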
fuzzy_match(query)
+
+
+ classmethod
+
+
+¶Fuzzy-match stocks/indexes.
+`query` may be a stock/index code, letters (matched against `name`, i.e. the pinyin abbreviation), or Chinese characters (matched against the display name).
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `query` | `str` | the query string | required |
+
+Returns:
+
+| Type | Description |
+|---|---|
+| `Dict[str, Tuple]` | matches keyed by code, where each Tuple is (code, display_name, name, start, end, type) |
+
omicron/models/stock.py
@classmethod
+def fuzzy_match(cls, query: str) -> Dict[str, Tuple]:
+ """对股票/指数进行模糊匹配查找
+
+ query可以是股票/指数代码,也可以是字母(按name查找),也可以是汉字(按显示名查找)
+
+ Args:
+ query (str): 查询字符串
+
+ Returns:
+ Dict[str, Tuple]: 查询结果,其中Tuple为(code, display_name, name, start, end, type)
+ """
+ query = query.upper()
+ if re.match(r"\d+", query):
+ return {
+ sec["code"]: sec.tolist()
+ for sec in cls._stocks
+ if sec["code"].startswith(query)
+ }
+ elif re.match(r"[A-Z]+", query):
+ return {
+ sec["code"]: sec.tolist()
+ for sec in cls._stocks
+ if sec["name"].startswith(query)
+ }
+ else:
+ return {
+ sec["code"]: sec.tolist()
+ for sec in cls._stocks
+ if sec["alias"].find(query) != -1
+ }
+
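+A usage sketch (it assumes the security list has been loaded, e.g. after omicron.init() has populated Stock._stocks; the query strings are illustrative):
+
+```python
+from omicron.models.stock import Stock
+
+# by code prefix (digits)
+print(Stock.fuzzy_match("60000"))
+# by pinyin abbreviation, matched against `name` (upper-cased internally)
+print(Stock.fuzzy_match("payh"))
+# by display-name substring (Chinese characters)
+print(Stock.fuzzy_match("平安"))
+```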
get_bars(code, n, frame_type, end=None, fq=True, unclosed=True)
+
+
+ async
+ classmethod
+
+
+¶Get up to `n` bars ending at `end`.
+The returned data is sorted in ascending time order. Bars falling in suspension periods are skipped, so the returned records may not be contiguous across trading days and may number fewer than `n`.
+If the system does not yet have data up to the given `end`, a best-effort result is returned; callers can check whether the timestamp of the last record equals `end` to tell whether the result is complete.
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `code` | `str` | security code | required |
+| `n` | `int` | number of records | required |
+| `frame_type` | `FrameType` | frame type | required |
+| `end` | `Union[datetime.date, datetime.datetime]` | end time; if not given, the current time is used | `None` |
+| `fq` |  | whether to adjust the returned records; if `True`, forward adjustment is applied. Defaults to True. | `True` |
+| `unclosed` |  | whether to include the latest unclosed bar. Defaults to True. | `True` |
+
+Returns:
+
+| Type | Description |
+|---|---|
+| `BarsArray` | a 1-D numpy array of dtype `coretypes.bars_dtype` |
+
omicron/models/stock.py
@classmethod
+async def get_bars(
+ cls,
+ code: str,
+ n: int,
+ frame_type: FrameType,
+ end: Frame = None,
+ fq=True,
+ unclosed=True,
+) -> BarsArray:
+ """获取到`end`为止的`n`个行情数据。
+
+ 返回的数据是按照时间顺序递增排序的。在遇到停牌的情况时,该时段数据将被跳过,因此返回的记录可能不是交易日连续的,并且可能不足`n`个。
+
+ 如果系统当前没有到指定时间`end`的数据,将尽最大努力返回数据。调用者可以通过判断最后一条数据的时间是否等于`end`来判断是否获取到了全部数据。
+
+ Args:
+ code: 证券代码
+ n: 记录数
+ frame_type: 帧类型
+ end: 截止时间,如果未指明,则取当前时间
+ fq: 是否对返回记录进行复权。如果为`True`的话,则进行前复权。Defaults to True.
+ unclosed: 是否包含最新未收盘的数据? Defaults to True.
+
+ Returns:
+ 返回dtype为`coretypes.bars_dtype`的一维numpy数组。
+ """
+ now = datetime.datetime.now()
+ try:
+ cached = np.array([], dtype=bars_dtype)
+
+ if frame_type in tf.day_level_frames:
+ if end is None:
+ end = now.date()
+ elif type(end) == datetime.datetime:
+ end = end.date()
+ n0 = n
+ if unclosed:
+ cached = await cls._get_cached_bars_n(code, 1, frame_type)
+ if cached.size > 0:
+ # 如果缓存的未收盘日期 > end,则该缓存不是需要的
+ if cached[0]["frame"].item().date() > end:
+ cached = np.array([], dtype=bars_dtype)
+ else:
+ n0 = n - 1
+ else:
+ end = end or now
+ closed_frame = tf.floor(end, frame_type)
+
+ # fetch one more bar, in case we should discard unclosed bar
+ cached = await cls._get_cached_bars_n(code, n + 1, frame_type, end)
+ if not unclosed:
+ cached = cached[cached["frame"] <= closed_frame]
+
+ # n bars we need fetch from persisted db
+ n0 = n - cached.size
+ if n0 > 0:
+ if cached.size > 0:
+ end0 = cached[0]["frame"].item()
+ else:
+ end0 = end
+
+ bars = await cls._get_persisted_bars_n(code, frame_type, n0, end0)
+ merged = np.concatenate((bars, cached))
+ bars = merged[-n:]
+ else:
+ bars = cached[-n:]
+
+ if fq:
+ bars = cls.qfq(bars)
+ return bars
+ except Exception as e:
+ logger.exception(e)
+ logger.warning(
+ "failed to get bars for %s, %s, %s, %s", code, n, frame_type, end
+ )
+ raise
+
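+A usage sketch (it assumes an initialized omicron session; code and count are illustrative):
+
+```python
+from coretypes import FrameType
+from omicron.models.stock import Stock
+
+async def latest_daily_closes(code: str = "000001.XSHE"):
+    # latest 30 daily bars, forward-adjusted; includes today's unclosed
+    # bar when called during a trading session
+    bars = await Stock.get_bars(code, 30, FrameType.DAY)
+    # best-effort semantics: the last frame may fall short of `end`
+    return bars["frame"][-1], bars["close"][-1]
+```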
get_bars_in_range(code, frame_type, start, end=None, fq=True, unclosed=True)
+
+
+ async
+ classmethod
+
+
+¶Get bars of frame type `frame_type` for the security `code` within [`start`, `end`].
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `code` |  | security code | required |
+| `frame_type` |  | frame type of the bars | required |
+| `start` |  | start time | required |
+| `end` |  | end time; if None, data is fetched up to the current time | `None` |
+| `fq` |  | whether to forward-adjust the bars | `True` |
+| `unclosed` |  | whether to include unclosed bars | `True` |
+
omicron/models/stock.py
@classmethod
+async def get_bars_in_range(
+ cls,
+ code: str,
+ frame_type: FrameType,
+ start: Frame,
+ end: Frame = None,
+ fq=True,
+ unclosed=True,
+) -> BarsArray:
+ """获取指定证券(`code`)在[`start`, `end`]期间帧类型为`frame_type`的行情数据。
+
+ Args:
+ code : 证券代码
+ frame_type : 行情数据的帧类型
+ start : 起始时间
+ end : 结束时间,如果为None,则表明取到当前时间。
+ fq : 是否对行情数据执行前复权操作
+ unclosed : 是否包含未收盘的数据
+ """
+ now = datetime.datetime.now()
+
+ if frame_type in tf.day_level_frames:
+ end = end or now.date()
+ if unclosed and tf.day_shift(end, 0) == now.date():
+ part2 = await cls._get_cached_bars_n(code, 1, frame_type)
+ else:
+ part2 = np.array([], dtype=bars_dtype)
+
+ # get rest from persisted
+ part1 = await cls._get_persisted_bars_in_range(code, frame_type, start, end)
+ bars = np.concatenate((part1, part2))
+ else:
+ end = end or now
+ closed_end = tf.floor(end, frame_type)
+ ff_min1 = tf.first_min_frame(now, FrameType.MIN1)
+ if tf.day_shift(end, 0) < now.date() or end < ff_min1:
+ part1 = await cls._get_persisted_bars_in_range(
+ code, frame_type, start, end
+ )
+ part2 = np.array([], dtype=bars_dtype)
+ elif start >= ff_min1: # all in cache
+ part1 = np.array([], dtype=bars_dtype)
+ n = tf.count_frames(start, closed_end, frame_type) + 1
+ part2 = await cls._get_cached_bars_n(code, n, frame_type, end)
+ part2 = part2[part2["frame"] >= start]
+ else: # in both cache and persisted
+ ff = tf.first_min_frame(now, frame_type)
+ part1 = await cls._get_persisted_bars_in_range(
+ code, frame_type, start, ff
+ )
+ n = tf.count_frames(ff, closed_end, frame_type) + 1
+ part2 = await cls._get_cached_bars_n(code, n, frame_type, end)
+
+ if not unclosed:
+ part2 = part2[part2["frame"] <= closed_end]
+ bars = np.concatenate((part1, part2))
+
+ if fq:
+ return cls.qfq(bars)
+ else:
+ return bars
+
get_latest_price(codes)
+
+
+ async
+ classmethod
+
+
+¶Get the latest (current trading day) prices for multiple stocks; indexes are not supported yet.
+Price data is refreshed every 5 seconds; multiple stocks can be queried at once, and the last cached price is returned.
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `codes` | `Iterable[str]` | list of security codes | required |
+
+Returns:
+
+| Type | Description |
+|---|---|
+| `List` | a list aligned with `codes`; each entry is the cached price parsed as a float, or None if no price is cached for that code |
+
omicron/models/stock.py
@classmethod
+async def get_latest_price(cls, codes: Iterable[str]) -> List[str]:
+ """获取多支股票的最新价格(交易日当天),暂不包括指数
+
+ 价格数据每5秒更新一次,接受多只股票查询,返回最后缓存的价格
+
+ Args:
+ codes: 代码列表
+
+ Returns:
+ 返回一个List,价格是字符形式的浮点数。
+ """
+ if not codes:
+ return []
+
+ _raw_code_list = []
+ for code_str in codes:
+ code, _ = code_str.split(".")
+ _raw_code_list.append(code)
+
+ _converted_data = []
+ raw_data = await cache.feature.hmget(TRADE_LATEST_PRICE, *_raw_code_list)
+ for _data in raw_data:
+ if _data is None:
+ _converted_data.append(_data)
+ else:
+ _converted_data.append(float(_data))
+ return _converted_data
+
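+A usage sketch (async context; it assumes the price-sync job has populated the feature cache):
+
+```python
+from omicron.models.stock import Stock
+
+async def show_latest_prices():
+    codes = ["000001.XSHE", "600000.XSHG"]
+    prices = await Stock.get_latest_price(codes)
+    # entries align with `codes`; None means no cached price for a code
+    return dict(zip(codes, prices))
+```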
get_trade_price_limits(code, begin, end)
+
+
+ async
+ classmethod
+
+
+¶Get a stock's price limits (limit-up/limit-down prices) within [begin, end], from influxdb and the cache.
+Price limits only exist for day-level data, so the FrameType is fixed to FrameType.DAY. The current day's values are kept in redis; if the query range covers the current trading day, they are read from the cache and appended to the result.
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `code` |  | stock code | required |
+| `begin` |  | start date | required |
+| `end` |  | end date | required |
+
+Returns:
+
+| Type | Description |
+|---|---|
+| `np.ndarray` | a numpy array of dtype `[('frame', 'O'), ('high_limit', 'f4'), ('low_limit', 'f4')]` |
+
omicron/models/stock.py
@classmethod
+async def get_trade_price_limits(
+ cls, code: str, begin: Frame, end: Frame
+) -> BarsArray:
+ """从influxdb和cache中获取个股在[begin, end]之间的涨跌停价。
+
+ 涨跌停价只有日线数据才有,因此,FrameType固定为FrameType.DAY,
+ 当天的数据存放于redis,如果查询日期包含当天(交易日),从cache中读取并追加到结果中
+
+ Args:
+ code : 个股代码
+ begin : 开始日期
+ end : 结束日期
+
+ Returns:
+ dtype为[('frame', 'O'), ('high_limit', 'f4'), ('low_limit', 'f4')]的numpy数组
+ """
+ cols = ["_time", "high_limit", "low_limit"]
+ dtype = [("frame", "O"), ("high_limit", "f4"), ("low_limit", "f4")]
+
+ if isinstance(begin, datetime.datetime):
+ begin = begin.date() # 强制转换为date
+ if isinstance(end, datetime.datetime):
+ end = end.date() # 强制转换为date
+
+ data_in_cache = await cls._get_price_limit_in_cache(code, begin, end)
+
+ client = get_influx_client()
+ measurement = cls._measurement_name(FrameType.DAY)
+ flux = (
+ Flux()
+ .bucket(client._bucket)
+ .measurement(measurement)
+ .range(begin, end)
+ .tags({"code": code})
+ .fields(cols)
+ .sort("_time")
+ )
+
+ ds = NumpyDeserializer(
+ dtype,
+ use_cols=cols,
+ converters={"_time": lambda x: ciso8601.parse_datetime(x).date()},
+ # since we ask parse date in convertors, so we have to disable parse_date
+ parse_date=None,
+ )
+
+ result = await client.query(flux, ds)
+ if data_in_cache:
+ result = np.concatenate([result, data_in_cache])
+ return result
+
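+A usage sketch (async context; the dates are illustrative):
+
+```python
+import datetime
+from omicron.models.stock import Stock
+
+async def show_limits():
+    limits = await Stock.get_trade_price_limits(
+        "000001.XSHE", datetime.date(2023, 1, 3), datetime.date(2023, 1, 31)
+    )
+    for rec in limits:  # dtype fields: frame, high_limit, low_limit
+        print(rec["frame"], rec["high_limit"], rec["low_limit"])
+```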
persist_bars(frame_type, bars)
+
+
+ async
+ classmethod
+
+
+¶Persist bars to storage.
+If `bars` is a Dict, its keys are `code`s and its values are bars. If it is a BarsArray or a pd.DataFrame, its columns should be `coretypes.bars_dtype` plus ("code", "O").
+
+Parameters:
+
+| Name | Type | Description | Default |
+|---|---|---|---|
+| `frame_type` | `FrameType` | the frame type of the bars | required |
+| `bars` | `Union[Dict[str, BarsArray], BarsArray, pd.DataFrame]` | the bars to be persisted | required |
+
+Exceptions:
+
+| Type | Description |
+|---|---|
+| `InfluxDBWriteError` | if the influxdb write failed |
+
omicron/models/stock.py
@classmethod
+async def persist_bars(
+ cls,
+ frame_type: FrameType,
+ bars: Union[Dict[str, BarsArray], BarsArray, pd.DataFrame],
+):
+ """将行情数据持久化
+
+ 如果`bars`类型为Dict,则key为`code`,value为`bars`。如果其类型为BarsArray或者pd.DataFrame,则`bars`各列字段应该为`coretypes.bars_dtype` + ("code", "O")构成。
+
+ Args:
+ frame_type: the frame type of the bars
+ bars: the bars to be persisted
+
+ Raises:
+ InfluxDBWriteError: if influxdb write failed
+ """
+ client = get_influx_client()
+
+ measurement = cls._measurement_name(frame_type)
+ logger.info("persisting bars to influxdb: %s, %d secs", measurement, len(bars))
+
+ if isinstance(bars, dict):
+ for code, value in bars.items():
+ await client.save(
+ value, measurement, global_tags={"code": code}, time_key="frame"
+ )
+ else:
+ await client.save(bars, measurement, tag_keys=["code"], time_key="frame")
+
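+A usage sketch (assumes omicron is initialized and `bars` is a BarsArray obtained elsewhere, e.g. from a quotes fetcher):
+
+```python
+from coretypes import FrameType
+from omicron.models.stock import Stock
+
+
+async def save_bars(bars):
+    # the dict form tags each value with its security code
+    await Stock.persist_bars(FrameType.DAY, {"000001.XSHE": bars})
+```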
qfq(bars)
+
+
+ staticmethod
+
+
+¶Apply forward adjustment (前复权) to bar data.
+omicron/models/stock.py
@staticmethod
+def qfq(bars: BarsArray) -> BarsArray:
+ """对行情数据执行前复权操作"""
+ # todo: 这里可以优化
+ if bars.size == 0:
+ return bars
+
+ last = bars[-1]["factor"]
+ for field in ["open", "high", "low", "close", "volume"]:
+ bars[field] = bars[field] * (bars["factor"] / last)
+
+ return bars
+
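+A usage sketch: fetch raw bars with fq=False, then forward-adjust. Note that qfq modifies the array in place, scaling each bar by its factor relative to the last bar's factor:
+
+```python
+from coretypes import FrameType
+from omicron.models.stock import Stock
+
+
+async def adjusted_close(code: str):
+    bars = await Stock.get_bars(code, 30, FrameType.DAY, fq=False)
+    return Stock.qfq(bars)["close"]
+```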
resample(bars, from_frame, to_frame)
+
+
+ classmethod
+
+
+¶Convert bars of `from_frame` into bars of `to_frame`.
+If `to_frame` is a daily or minute-level frame, `from_frame` must be a minute-level frame; if `to_frame` is weekly or above, `from_frame` must be daily. Conversion between other levels is not supported.
+If `from_frame` is 1-minute bars, the series must start at 9:31.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+bars | BarsArray | the bar data | required
+from_frame | FrameType | the FrameType before conversion | required
+to_frame | FrameType | the FrameType after conversion | required
+
+Returns:
+
+Type | Description
+---|---
+BarsArray | the converted bars
omicron/models/stock.py
@classmethod
+def resample(
+ cls, bars: BarsArray, from_frame: FrameType, to_frame: FrameType
+) -> BarsArray:
+ """将原来为`from_frame`的行情数据转换为`to_frame`的行情数据
+
+ 如果`to_frame`为日线或者分钟级别线,则`from_frame`必须为分钟线;如果`to_frame`为周以上级别线,则`from_frame`必须为日线。其它级别之间的转换不支持。
+
+ 如果`from_frame`为1分钟线,则必须从9:31起。
+
+ Args:
+ bars (BarsArray): 行情数据
+ from_frame (FrameType): 转换前的FrameType
+ to_frame (FrameType): 转换后的FrameType
+
+ Returns:
+ BarsArray: 转换后的行情数据
+ """
+ if from_frame == FrameType.MIN1:
+ return cls._resample_from_min1(bars, to_frame)
+ elif from_frame == FrameType.DAY: # pragma: no cover
+ return cls._resample_from_day(bars, to_frame)
+ else: # pragma: no cover
+ raise TypeError(f"unsupported from_frame: {from_frame}")
+
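+For example, 1-minute bars starting at 09:31 can be aggregated into 30-minute bars (a sketch; `min1_bars` is assumed to be fetched elsewhere):
+
+```python
+from coretypes import FrameType
+from omicron.models.stock import Stock
+
+
+def to_30min(min1_bars):
+    # min1_bars must start at 09:31 when resampling from MIN1
+    return Stock.resample(min1_bars, FrameType.MIN1, FrameType.MIN30)
+```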
reset_cache()
+
+
+ async
+ classmethod
+
+
+¶Clear cached bar data.
+omicron/models/stock.py
@classmethod
+async def reset_cache(cls):
+ """清除缓存的行情数据"""
+ try:
+ for ft in itertools.chain(tf.minute_level_frames, tf.day_level_frames):
+ keys = await cache.security.keys(f"bars:{ft.value}:*")
+ if keys:
+ await cache.security.delete(*keys)
+ finally:
+ cls._is_cache_empty = True
+
save_trade_price_limits(price_limits, to_cache)
+
+
+ async
+ classmethod
+
+
+¶Save trade price limits.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+price_limits | LimitPriceOnlyBarsArray | the price-limit records to save | required
+to_cache | bool | save to the cache if True, otherwise to persistent storage | required
omicron/models/stock.py
@classmethod
+async def save_trade_price_limits(
+ cls, price_limits: LimitPriceOnlyBarsArray, to_cache: bool
+):
+ """保存涨跌停价
+
+ Args:
+ price_limits: 要保存的涨跌停价格数据。
+ to_cache: 是保存到缓存中,还是保存到持久化存储中
+ """
+ if len(price_limits) == 0:
+ return
+
+ if to_cache: # 每个交易日上午9点更新两次
+ pl = cache._security_.pipeline()
+ for row in price_limits:
+ # .item convert np.float64 to python float
+ pl.hset(
+ TRADE_PRICE_LIMITS,
+ f"{row['code']}.high_limit",
+ row["high_limit"].item(),
+ )
+ pl.hset(
+ TRADE_PRICE_LIMITS,
+ f"{row['code']}.low_limit",
+ row["low_limit"].item(),
+ )
+
+ dt = price_limits[-1]["frame"]
+ pl.set(TRADE_PRICE_LIMITS_DATE, dt.strftime("%Y-%m-%d"))
+ await pl.execute()
+ else:
+ # to influxdb, 每个交易日的第二天早上2点保存
+ client = get_influx_client()
+ await client.save(
+ price_limits,
+ cls._measurement_name(FrameType.DAY),
+ tag_keys="code",
+ time_key="frame",
+ )
+
trade_price_limit_flags(code, start, end)
+
+
+ async
+ classmethod
+
+
+¶Get a security's limit-up/limit-down flags between [start, end].
+Note
+If trading of the security was suspended, the returned series will not map one-to-one onto [start, end].
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+code | str | security code | required
+start | date | start date | required
+end | date | end date | required
+
+Returns:
+
+Type | Description
+---|---
+Tuple[List[bool]] | the limit flags as (buy, sell) lists
omicron/models/stock.py
@classmethod
+async def trade_price_limit_flags(
+ cls, code: str, start: datetime.date, end: datetime.date
+) -> Tuple[List[bool]]:
+ """获取个股在[start, end]之间的涨跌停标志
+
+ !!!Note
+ 本函数返回的序列在股票有停牌的情况下,将不能与[start, end]一一对应。
+
+ Args:
+ code: 个股代码
+ start: 开始日期
+ end: 结束日期
+
+ Returns:
+ 涨跌停标志列表(buy, sell)
+ """
+ cols = ["_time", "close", "high_limit", "low_limit"]
+ client = get_influx_client()
+ measurement = cls._measurement_name(FrameType.DAY)
+ flux = (
+ Flux()
+ .bucket(client._bucket)
+ .measurement(measurement)
+ .range(start, end)
+ .tags({"code": code})
+ .fields(cols)
+ .sort("_time")
+ )
+
+ dtype = [
+ ("frame", "O"),
+ ("close", "f4"),
+ ("high_limit", "f4"),
+ ("low_limit", "f4"),
+ ]
+ ds = NumpyDeserializer(
+ dtype,
+ use_cols=["_time", "close", "high_limit", "low_limit"],
+ converters={"_time": lambda x: ciso8601.parse_datetime(x).date()},
+        # since we parse dates in converters, we disable parse_date here
+ parse_date=None,
+ )
+
+ result = await client.query(flux, ds)
+ if result.size == 0:
+ return np.array([], dtype=dtype)
+
+ return (
+ array_price_equal(result["close"], result["high_limit"]),
+ array_price_equal(result["close"], result["low_limit"]),
+ )
+
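+A sketch counting limit-up closes over a quarter (assumes the daily data is available; note the flags skip suspended days):
+
+```python
+import datetime
+
+from omicron.models.stock import Stock
+
+
+async def limit_up_days(code: str):
+    buy_flags, sell_flags = await Stock.trade_price_limit_flags(
+        code, datetime.date(2023, 1, 3), datetime.date(2023, 3, 31)
+    )
+    # buy_flags[i] is True where the close equals the limit-up price
+    return sum(buy_flags)
+```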
trade_price_limit_flags_ex(code, start, end)
+
+
+ async
+ classmethod
+
+
+¶Get the limit-up/limit-down flags of `code` within `[start, end]`.
+Note
+If `end` is the current day, the flags are unstable until the market closes.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+code | str | security code | required
+start | date | start date | required
+end | date | end date | required
+
+Returns:
+
+Type | Description
+---|---
+Dict[datetime.date, Tuple[bool, bool]] | a dict keyed by date, whose values are (limit-up, limit-down) flags
omicron/models/stock.py
@classmethod
+async def trade_price_limit_flags_ex(
+ cls, code: str, start: datetime.date, end: datetime.date
+) -> Dict[datetime.date, Tuple[bool, bool]]:
+ """获取股票`code`在`[start, end]`区间的涨跌停标志
+
+ !!!Note
+ 如果end为当天,注意在未收盘之前,这个涨跌停标志都是不稳定的
+
+ Args:
+ code: 股票代码
+ start: 起始日期
+ end: 结束日期
+
+ Returns:
+ 以日期为key,(涨停,跌停)为值的dict
+ """
+ limit_prices = await cls.get_trade_price_limits(code, start, end)
+ bars = await Stock.get_bars_in_range(
+ code, FrameType.DAY, start=start, end=end, fq=False
+ )
+
+ close = bars["close"]
+
+ results = {}
+
+ # aligned = True
+ for i in range(len(bars)):
+ if bars[i]["frame"].item().date() != limit_prices[i]["frame"]:
+ # aligned = False
+ logger.warning("数据同步错误,涨跌停价格与收盘价时间不一致: %s, %s", code, bars[i]["frame"])
+ break
+
+ results[limit_prices[i]["frame"]] = (
+ price_equal(limit_prices[i]["high_limit"], close[i]),
+ price_equal(limit_prices[i]["low_limit"], close[i]),
+ )
+
+ # if not aligned:
+ # bars = bars[i:]
+ # limit_prices = limit_prices[i:]
+
+ # for frame in bars["frame"]:
+ # frame = frame.item().date()
+ # close = bars[bars["frame"].item().date() == frame]["close"].item()
+ # high = limit_prices[limit_prices["frame"] == frame]["high_limit"].item()
+ # low = limit_prices[limit_prices["frame"] == frame]["low_limit"].item()
+ # results[frame] = (
+ # price_equal(high, close),
+ # price_equal(low, close)
+ # )
+
+ return results
+
base
+
+
+
+¶
+BacktestState
+
+
+
+ dataclass
+
+
+¶BacktestState(start: Frame, end: Frame, barss: Union[None, Dict[str, BarsArray]], cursor: int, warmup_peroid: int, baseline: str = '399300.XSHE')
+omicron/strategy/base.py
class BacktestState(object):
+ start: Frame
+ end: Frame
+ barss: Union[None, Dict[str, BarsArray]]
+ cursor: int
+ warmup_peroid: int
+ baseline: str = "399300.XSHE"
+
+BaseStrategy
+
+
+
+¶omicron/strategy/base.py
class BaseStrategy:
+ def __init__(
+ self,
+ url: str,
+ account: Optional[str] = None,
+ token: Optional[str] = None,
+ name: Optional[str] = None,
+ ver: Optional[str] = None,
+ is_backtest: bool = True,
+ start: Optional[Frame] = None,
+ end: Optional[Frame] = None,
+ frame_type: Optional[FrameType] = None,
+ warmup_period: int = 0,
+ ):
+ """构造函数
+
+ Args:
+ url: 实盘/回测服务器的地址。
+ start: 回测起始日期。回测模式下必须传入。
+ end: 回测结束日期。回测模式下必须传入。
+ account: 实盘/回测账号。实盘模式下必须传入。在回测模式下,如果未传入,将以策略名+随机字符构建账号。
+ token: 实盘/回测时用的token。实盘模式下必须传入。在回测模式下,如果未传入,将自动生成。
+ is_backtest: 是否为回测模式。
+ name: 策略名。如果不传入,则使用类名字小写
+ ver: 策略版本号。如果不传入,则默认为0.1.
+ start: 如果是回测模式,则需要提供回测起始时间
+ end: 如果是回测模式,则需要提供回测结束时间
+ frame_type: 如果是回测模式,则需要提供回测时使用的主周期
+ warmup_period: 策略执行时需要的最小bar数(以frame_type)计。
+ """
+ self.ver = ver or "0.1"
+ self.name = name or self.__class__.__name__.lower() + f"_v{self.ver}"
+
+ self.token = token or uuid.uuid4().hex
+ self.account = account or f"smallcap-{self.token[-4:]}"
+
+ self.url = url
+ self.bills = None
+ self.metrics = None
+
+ # used by both live and backtest
+ self.warmup_period = warmup_period
+ self.is_backtest = is_backtest
+ if is_backtest:
+ if start is None or end is None or frame_type is None:
+ raise ValueError("start, end and frame_type must be presented.")
+
+ self.bs = BacktestState(start, end, None, 0, warmup_period)
+ self._frame_type = frame_type
+ self.broker = TraderClient(
+ url,
+ self.account,
+ self.token,
+ is_backtest=True,
+ start=self.bs.start,
+ end=self.bs.end,
+ )
+ else:
+ if account is None or token is None:
+ raise ValueError("account and token must be presented.")
+
+ self.broker = TraderClient(url, self.account, self.token, is_backtest=False)
+
+ async def _cache_bars_for_backtest(self, portfolio: List[str], n: int):
+ if portfolio is None or len(portfolio) == 0:
+ return
+
+ count = tf.count_frames(self.bs.start, self.bs.end, self._frame_type)
+ tasks = [
+ Stock.get_bars(code, count + n, self._frame_type, self.bs.end, fq=False)
+ for code in portfolio
+ ]
+
+ results = await gather(*tasks)
+ self.bs.barss = {k: v for (k, v) in zip(portfolio, results)}
+
+ def _next(self):
+ if self.bs.barss is None:
+ return None
+
+ self.bs.cursor += 1
+ return {
+ k: Stock.qfq(v[self.bs.cursor - self.bs.warmup_peroid : self.bs.cursor])
+ for (k, v) in self.bs.barss.items()
+ }
+
+ async def peek(self, code: str, n: int):
+ """允许策略偷看未来数据
+
+ 可用以因子检验场景。要求数据本身已缓存。否则请用Stock.get_bars等方法获取。
+ """
+ if self.bs is None or self.bs.barss is None:
+ raise ValueError("data is not cached")
+
+ if code in self.bs.barss:
+ if self.bs.cursor + n + 1 < len(self.bs.barss[code]):
+ return Stock.qfq(
+ self.bs.barss[code][self.bs.cursor : self.bs.cursor + n]
+ )
+
+ else:
+ raise ValueError("data is not cached")
+
+ async def backtest(self, stop_on_error: bool = True, **kwargs):
+ """执行回测
+
+ Args:
+ stop_on_error: 如果为True,则发生异常时,将停止回测。否则忽略错误,继续执行。
+ Keyword Args:
+            prefetch_stocks List[str]: 代码列表。在该列表中的品种,将在回测之前自动预取行情数据,并在调用predict时,传入截止到当前frame的行情数据。行情周期由构造时的frame_type指定,预取数据长度由`self.warmup_period`决定
+ """
+ prefetch_stocks: List[str] = kwargs.get("prefetch_stocks") # type: ignore
+ await self._cache_bars_for_backtest(prefetch_stocks, self.warmup_period)
+ self.bs.cursor = self.warmup_period
+
+ intra_day = self._frame_type in tf.minute_level_frames
+ converter = tf.int2time if intra_day else tf.int2date
+
+ await self.before_start()
+
+ # 最后一周期不做预测,留出来执行上一周期的信号
+ end_ = tf.shift(self.bs.end, -1, self._frame_type)
+ for i, frame in enumerate(
+ tf.get_frames(self.bs.start, end_, self._frame_type) # type: ignore
+ ):
+ barss = self._next()
+ day_barss = barss if self._frame_type == FrameType.DAY else None
+ frame_ = converter(frame)
+
+ prev_frame = tf.shift(frame_, -1, self._frame_type)
+ next_frame = tf.shift(frame_, 1, self._frame_type)
+
+ # new trading day start
+ if (not intra_day and prev_frame < frame_) or (
+ intra_day and prev_frame.date() < frame_.date()
+ ):
+ await self.before_trade(frame_, day_barss)
+
+ logger.debug("%sth iteration", i, date=frame_)
+ try:
+ await self.predict(
+ frame_, self._frame_type, i, barss=barss, **kwargs # type: ignore
+ )
+ except Exception as e:
+ if isinstance(e, TradeError):
+ logger.warning("call stack is:\n%s", e.stack)
+ else:
+ logger.exception(e)
+ if stop_on_error:
+ raise e
+
+ # trading day ends
+ if (not intra_day and next_frame > frame_) or (
+ intra_day and next_frame.date() > frame_.date()
+ ):
+ await self.after_trade(frame_, day_barss)
+
+ self.broker.stop_backtest()
+
+ await self.after_stop()
+ self.bills = self.broker.bills()
+ baseline = kwargs.get("baseline", "399300.XSHE")
+ self.metrics = self.broker.metrics(baseline=baseline)
+ self.bs.baseline = baseline
+
+ @property
+ def cash(self):
+ """返回当前可用现金"""
+ return self.broker.available_money
+
+ def positions(self, dt: Optional[datetime.date] = None):
+ """返回当前持仓"""
+ return self.broker.positions(dt)
+
+ def available_shares(self, sec: str, dt: Optional[Frame] = None):
+ """返回给定股票在`dt`日的可售股数
+
+ Args:
+ sec: 证券代码
+ dt: 日期,在实盘中无意义,只能返回最新数据;在回测时,必须指定日期,且返回指定日期下的持仓。
+ """
+ return self.broker.available_shares(sec, dt)
+
+ async def buy(
+ self,
+ sec: str,
+ price: Optional[float] = None,
+ vol: Optional[int] = None,
+ money: Optional[float] = None,
+ order_time: Optional[datetime.datetime] = None,
+ ) -> Dict:
+ """买入股票
+
+ Args:
+ sec: 证券代码
+ price: 委买价。如果为None,则自动转市价买入。
+ vol: 委买股数。请自行保证为100的整数。如果为None, 则money必须传入。
+ money: 委买金额。如果同时传入了vol,则此参数自动忽略
+ order_time: 仅在回测模式下需要提供。实盘模式下,此参数自动被忽略
+ Returns:
+ 见traderclient中的`buy`方法。
+ """
+ logger.debug(
+ "buy order: %s, %s, %s, %s",
+ sec,
+ f"{price:.2f}" if price is not None else None,
+ f"{vol:.0f}" if vol is not None else None,
+ f"{money:.0f}" if money is not None else None,
+ date=order_time,
+ )
+
+ if vol is None:
+ if money is None:
+                raise ValueError("parameter `money` must be presented!")
+
+ return await self.broker.buy_by_money(
+ sec, money, price, order_time=order_time
+ )
+ elif price is None:
+ return self.broker.market_buy(sec, vol, order_time=order_time)
+ else:
+ return self.broker.buy(sec, price, vol, order_time=order_time)
+
+ async def sell(
+ self,
+ sec: str,
+ price: Optional[float] = None,
+ vol: Optional[float] = None,
+ percent: Optional[float] = None,
+ order_time: Optional[datetime.datetime] = None,
+ ) -> Union[List, Dict]:
+ """卖出股票
+
+ Args:
+ sec: 证券代码
+ price: 委卖价,如果未提供,则转为市价单
+ vol: 委卖股数。如果为None,则percent必须传入
+ percent: 卖出一定比例的持仓,取值介于0与1之间。如果与vol同时提供,此参数将被忽略。请自行保证按比例换算后的卖出数据是符合要求的(比如不为100的倍数,但有些情况下这是允许的,所以程序这里无法帮你判断)
+ order_time: 仅在回测模式下需要提供。实盘模式下,此参数自动被忽略
+
+ Returns:
+ Union[List, Dict]: 成交返回,详见traderclient中的`buy`方法,trade server只返回一个委托单信息
+ """
+ logger.debug(
+ "sell order: %s, %s, %s, %s",
+ sec,
+ f"{price:.2f}" if price is not None else None,
+ f"{vol:.0f}" if vol is not None else None,
+ f"{percent:.2%}" if percent is not None else None,
+ date=order_time,
+ )
+
+ if vol is None and percent is None:
+ raise ValueError("either vol or percent must be presented")
+
+ if vol is None:
+ if price is None:
+ price = await self.broker._get_market_sell_price(
+ sec, order_time=order_time
+ )
+ # there's no market_sell_percent API in traderclient
+ return self.broker.sell_percent(sec, price, percent, order_time=order_time) # type: ignore
+ else:
+ if price is None:
+ return self.broker.market_sell(sec, vol, order_time=order_time)
+ else:
+ return self.broker.sell(sec, price, vol, order_time=order_time)
+
+ async def filter_paused_stock(self, buylist: List[str], dt: datetime.date):
+ secs = await Security.select(dt).eval()
+ in_trading = jq.get_price(
+ secs, fields=["paused"], start_date=dt, end_date=dt, skip_paused=True
+ )["code"].to_numpy()
+
+ return np.intersect1d(buylist, in_trading)
+
+ async def before_start(self):
+ """策略启动前的准备工作。
+
+ 在一次回测中,它会在backtest中、进入循环之前调用。如果策略需要根据过去的数据来计算一些自适应参数,可以在此方法中实现。
+ """
+ if self.bs is not None:
+ logger.info(
+ "BEFORE_START: %s<%s - %s>",
+ self.name,
+ self.bs.start,
+ self.bs.end,
+ date=self.bs.start,
+ )
+ else:
+ logger.info("BEFORE_START: %s", self.name)
+
+ async def before_trade(self, date: datetime.date, barss: Optional[Dict[str, BarsArray]]=None):
+ """每日开盘前的准备工作
+
+ Args:
+ date: 日期。在回测中为回测当日日期,在实盘中为系统日期
+ barss: 如果主周期为日线,且支持预取,则会将预取的barss传入
+ """
+ logger.debug("BEFORE_TRADE: %s", self.name, date=date)
+
+ async def after_trade(self, date: Frame, barss: Optional[Dict[str, BarsArray]]=None):
+ """每日收盘后的收尾工作
+
+ Args:
+ date: 日期。在回测中为回测当日日期,在实盘中为系统日期
+ barss: 如果主周期为日线,且支持预取,则会将预取的barss传入
+ """
+ logger.debug("AFTER_TRADE: %s", self.name, date=date)
+
+ async def after_stop(self):
+ if self.bs is not None:
+ logger.info(
+ "STOP %s<%s - %s>",
+ self.name,
+ self.bs.start,
+ self.bs.end,
+ date=self.bs.end,
+ )
+ else:
+ logger.info("STOP %s", self.name)
+
+ async def predict(
+ self,
+ frame: Frame,
+ frame_type: FrameType,
+ i: int,
+ barss: Optional[Dict[str, BarsArray]] = None,
+ **kwargs,
+ ):
+ """策略评估函数。在此函数中实现交易信号检测和处理。
+
+ Args:
+ frame: 当前时间帧
+ frame_type: 处理的数据主周期
+ i: 当前时间离回测起始的单位数
+            barss: 如果调用`backtest`时传入了`prefetch_stocks`参数,则`backtest`将会在回测之前,预取从[start - warmup_period * frame_type, end]间的行情数据,并在每次调用`predict`方法时,通过`barss`参数,将[start - warmup_period * frame_type, start + i * frame_type]间的数据传给`predict`方法。传入的数据已进行前复权。
+
+ Keyword Args: 在`backtest`方法中的传入的kwargs参数将被透传到此方法中。
+ """
+ raise NotImplementedError
+
+ @deprecated("2.0.0", details="use `make_report` instead")
+ async def plot_metrics(
+ self, indicator: Union[pd.DataFrame, List[Tuple], None] = None
+ ):
+ return await self.make_report(indicator)
+
+ async def make_report(
+ self, indicator: Union[pd.DataFrame, List[Tuple], None] = None
+ ):
+ """策略回测报告
+
+ Args:
+ indicator: 回测时使用的指标。如果存在,将叠加到策略回测图上。它应该是一个以日期为索引,指标列名为"value"的DataFrame
+ """
+ if self.bills is None or self.metrics is None:
+ raise ValueError("Please run `start_backtest` first.")
+
+ if isinstance(indicator, list):
+ assert len(indicator[0]) == 2
+ indicator = pd.DataFrame(indicator, columns=["date", "value"])
+ indicator.set_index("date", inplace=True)
+
+ mg = MetricsGraph(
+ self.bills,
+ self.metrics,
+ indicator=indicator,
+ baseline_code=self.bs.baseline,
+ )
+ await mg.plot()
+
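+A minimal end-to-end sketch of subclassing BaseStrategy (the server URL is a placeholder, `tf` is assumed importable from omicron, and the buy-once logic is illustrative only):
+
+```python
+import datetime
+
+from coretypes import FrameType
+from omicron import tf
+from omicron.strategy.base import BaseStrategy
+
+
+class BuyAndHold(BaseStrategy):
+    async def predict(self, frame, frame_type, i, barss=None, **kwargs):
+        if i == 0:  # buy once on the first frame, then hold
+            await self.buy(
+                "000001.XSHE",
+                money=self.cash,
+                order_time=tf.combine_time(frame, 14, 55),
+            )
+
+
+async def run():
+    s = BuyAndHold(
+        url="http://localhost:7080/backtest/api/trade/v0.2",  # placeholder
+        is_backtest=True,
+        start=datetime.date(2023, 1, 4),
+        end=datetime.date(2023, 3, 31),
+        frame_type=FrameType.DAY,
+    )
+    await s.backtest()
+    await s.make_report()
+```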
cash
+
+
+ property
+ readonly
+
+
+¶Return the currently available cash.
+__init__(self, url, account=None, token=None, name=None, ver=None, is_backtest=True, start=None, end=None, frame_type=None, warmup_period=0)
+
+
+ special
+
+
+¶Constructor.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+url | str | address of the live-trading/backtest server | required
+account | Optional[str] | live-trading/backtest account. Required in live mode; in backtest mode, defaults to the strategy name plus random characters | None
+token | Optional[str] | token for live trading/backtest. Required in live mode; auto-generated in backtest mode if omitted | None
+name | Optional[str] | strategy name; defaults to the lower-cased class name | None
+ver | Optional[str] | strategy version; defaults to 0.1 | None
+is_backtest | bool | whether to run in backtest mode | True
+start | Optional[Frame] | backtest start date; required in backtest mode | None
+end | Optional[Frame] | backtest end date; required in backtest mode | None
+frame_type | Optional[FrameType] | main frame type used during the backtest; required in backtest mode | None
+warmup_period | int | minimum number of bars (in units of frame_type) the strategy needs | 0
omicron/strategy/base.py
def __init__(
+ self,
+ url: str,
+ account: Optional[str] = None,
+ token: Optional[str] = None,
+ name: Optional[str] = None,
+ ver: Optional[str] = None,
+ is_backtest: bool = True,
+ start: Optional[Frame] = None,
+ end: Optional[Frame] = None,
+ frame_type: Optional[FrameType] = None,
+ warmup_period: int = 0,
+):
+ """构造函数
+
+ Args:
+ url: 实盘/回测服务器的地址。
+ start: 回测起始日期。回测模式下必须传入。
+ end: 回测结束日期。回测模式下必须传入。
+ account: 实盘/回测账号。实盘模式下必须传入。在回测模式下,如果未传入,将以策略名+随机字符构建账号。
+ token: 实盘/回测时用的token。实盘模式下必须传入。在回测模式下,如果未传入,将自动生成。
+ is_backtest: 是否为回测模式。
+ name: 策略名。如果不传入,则使用类名字小写
+ ver: 策略版本号。如果不传入,则默认为0.1.
+ start: 如果是回测模式,则需要提供回测起始时间
+ end: 如果是回测模式,则需要提供回测结束时间
+ frame_type: 如果是回测模式,则需要提供回测时使用的主周期
+ warmup_period: 策略执行时需要的最小bar数(以frame_type)计。
+ """
+ self.ver = ver or "0.1"
+ self.name = name or self.__class__.__name__.lower() + f"_v{self.ver}"
+
+ self.token = token or uuid.uuid4().hex
+ self.account = account or f"smallcap-{self.token[-4:]}"
+
+ self.url = url
+ self.bills = None
+ self.metrics = None
+
+ # used by both live and backtest
+ self.warmup_period = warmup_period
+ self.is_backtest = is_backtest
+ if is_backtest:
+ if start is None or end is None or frame_type is None:
+ raise ValueError("start, end and frame_type must be presented.")
+
+ self.bs = BacktestState(start, end, None, 0, warmup_period)
+ self._frame_type = frame_type
+ self.broker = TraderClient(
+ url,
+ self.account,
+ self.token,
+ is_backtest=True,
+ start=self.bs.start,
+ end=self.bs.end,
+ )
+ else:
+ if account is None or token is None:
+ raise ValueError("account and token must be presented.")
+
+ self.broker = TraderClient(url, self.account, self.token, is_backtest=False)
+
after_trade(self, date, barss=None)
+
+
+ async
+
+
+¶Wrap-up work after each trading day's close.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+date | Frame | the date: the backtest day in backtest mode, the system date in live mode | required
+barss | Optional[Dict[str, BarsArray]] | the prefetched bars, passed in when the main frame type is daily and prefetching is enabled | None
omicron/strategy/base.py
async def after_trade(self, date: Frame, barss: Optional[Dict[str, BarsArray]]=None):
+ """每日收盘后的收尾工作
+
+ Args:
+ date: 日期。在回测中为回测当日日期,在实盘中为系统日期
+ barss: 如果主周期为日线,且支持预取,则会将预取的barss传入
+ """
+ logger.debug("AFTER_TRADE: %s", self.name, date=date)
+
available_shares(self, sec, dt=None)
+
+
+¶Return the sellable shares of the given security on day `dt`.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+sec | str | security code | required
+dt | Optional[Frame] | the date. Meaningless in live mode, where only the latest data is returned; required in backtest mode, where the position on that date is returned | None
omicron/strategy/base.py
def available_shares(self, sec: str, dt: Optional[Frame] = None):
+ """返回给定股票在`dt`日的可售股数
+
+ Args:
+ sec: 证券代码
+ dt: 日期,在实盘中无意义,只能返回最新数据;在回测时,必须指定日期,且返回指定日期下的持仓。
+ """
+ return self.broker.available_shares(sec, dt)
+
backtest(self, stop_on_error=True, **kwargs)
+
+
+ async
+
+
+¶Run the backtest.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+stop_on_error | bool | if True, stop the backtest when an exception occurs; otherwise log the error and continue | True
+
+Keyword arguments:
+
+Name | Type | Description
+---|---|---
+prefetch_stocks | List[str] | list of security codes. Bars for these securities are prefetched before the backtest, and on each call to predict the bars up to the current frame are passed in. The frame type is the one given at construction; the prefetch length is determined by `self.warmup_period`
omicron/strategy/base.py
async def backtest(self, stop_on_error: bool = True, **kwargs):
+ """执行回测
+
+ Args:
+ stop_on_error: 如果为True,则发生异常时,将停止回测。否则忽略错误,继续执行。
+ Keyword Args:
+        prefetch_stocks List[str]: 代码列表。在该列表中的品种,将在回测之前自动预取行情数据,并在调用predict时,传入截止到当前frame的行情数据。行情周期由构造时的frame_type指定,预取数据长度由`self.warmup_period`决定
+ """
+ prefetch_stocks: List[str] = kwargs.get("prefetch_stocks") # type: ignore
+ await self._cache_bars_for_backtest(prefetch_stocks, self.warmup_period)
+ self.bs.cursor = self.warmup_period
+
+ intra_day = self._frame_type in tf.minute_level_frames
+ converter = tf.int2time if intra_day else tf.int2date
+
+ await self.before_start()
+
+ # 最后一周期不做预测,留出来执行上一周期的信号
+ end_ = tf.shift(self.bs.end, -1, self._frame_type)
+ for i, frame in enumerate(
+ tf.get_frames(self.bs.start, end_, self._frame_type) # type: ignore
+ ):
+ barss = self._next()
+ day_barss = barss if self._frame_type == FrameType.DAY else None
+ frame_ = converter(frame)
+
+ prev_frame = tf.shift(frame_, -1, self._frame_type)
+ next_frame = tf.shift(frame_, 1, self._frame_type)
+
+ # new trading day start
+ if (not intra_day and prev_frame < frame_) or (
+ intra_day and prev_frame.date() < frame_.date()
+ ):
+ await self.before_trade(frame_, day_barss)
+
+ logger.debug("%sth iteration", i, date=frame_)
+ try:
+ await self.predict(
+ frame_, self._frame_type, i, barss=barss, **kwargs # type: ignore
+ )
+ except Exception as e:
+ if isinstance(e, TradeError):
+ logger.warning("call stack is:\n%s", e.stack)
+ else:
+ logger.exception(e)
+ if stop_on_error:
+ raise e
+
+ # trading day ends
+ if (not intra_day and next_frame > frame_) or (
+ intra_day and next_frame.date() > frame_.date()
+ ):
+ await self.after_trade(frame_, day_barss)
+
+ self.broker.stop_backtest()
+
+ await self.after_stop()
+ self.bills = self.broker.bills()
+ baseline = kwargs.get("baseline", "399300.XSHE")
+ self.metrics = self.broker.metrics(baseline=baseline)
+ self.bs.baseline = baseline
+
before_start(self)
+
+
+ async
+
+
+¶Preparation before the strategy starts.
+During a backtest, this is called inside `backtest` before entering the loop. If the strategy needs to derive adaptive parameters from historical data, implement that here.
+omicron/strategy/base.py
async def before_start(self):
+ """策略启动前的准备工作。
+
+ 在一次回测中,它会在backtest中、进入循环之前调用。如果策略需要根据过去的数据来计算一些自适应参数,可以在此方法中实现。
+ """
+ if self.bs is not None:
+ logger.info(
+ "BEFORE_START: %s<%s - %s>",
+ self.name,
+ self.bs.start,
+ self.bs.end,
+ date=self.bs.start,
+ )
+ else:
+ logger.info("BEFORE_START: %s", self.name)
+
before_trade(self, date, barss=None)
+
+
+ async
+
+
+¶Preparation before each trading day's open.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+date | date | the date: the backtest day in backtest mode, the system date in live mode | required
+barss | Optional[Dict[str, BarsArray]] | the prefetched bars, passed in when the main frame type is daily and prefetching is enabled | None
omicron/strategy/base.py
async def before_trade(self, date: datetime.date, barss: Optional[Dict[str, BarsArray]]=None):
+ """每日开盘前的准备工作
+
+ Args:
+ date: 日期。在回测中为回测当日日期,在实盘中为系统日期
+ barss: 如果主周期为日线,且支持预取,则会将预取的barss传入
+ """
+ logger.debug("BEFORE_TRADE: %s", self.name, date=date)
+
buy(self, sec, price=None, vol=None, money=None, order_time=None)
+
+
+ async
+
+
+¶Buy a security.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+sec | str | security code | required
+price | Optional[float] | bid price; if None, a market-price order is placed | None
+vol | Optional[int] | shares to buy; you must ensure it is a multiple of 100. If None, money must be given | None
+money | Optional[float] | amount to spend; ignored when vol is also given | None
+order_time | Optional[datetime.datetime] | required only in backtest mode; ignored in live mode | None
+
+Returns:
+
+Type | Description
+---|---
+Dict | see the `buy` method in traderclient
omicron/strategy/base.py
async def buy(
+ self,
+ sec: str,
+ price: Optional[float] = None,
+ vol: Optional[int] = None,
+ money: Optional[float] = None,
+ order_time: Optional[datetime.datetime] = None,
+) -> Dict:
+ """买入股票
+
+ Args:
+ sec: 证券代码
+ price: 委买价。如果为None,则自动转市价买入。
+ vol: 委买股数。请自行保证为100的整数。如果为None, 则money必须传入。
+ money: 委买金额。如果同时传入了vol,则此参数自动忽略
+ order_time: 仅在回测模式下需要提供。实盘模式下,此参数自动被忽略
+ Returns:
+ 见traderclient中的`buy`方法。
+ """
+ logger.debug(
+ "buy order: %s, %s, %s, %s",
+ sec,
+ f"{price:.2f}" if price is not None else None,
+ f"{vol:.0f}" if vol is not None else None,
+ f"{money:.0f}" if money is not None else None,
+ date=order_time,
+ )
+
+ if vol is None:
+ if money is None:
+            raise ValueError("parameter `money` must be presented!")
+
+ return await self.broker.buy_by_money(
+ sec, money, price, order_time=order_time
+ )
+ elif price is None:
+ return self.broker.market_buy(sec, vol, order_time=order_time)
+ else:
+ return self.broker.buy(sec, price, vol, order_time=order_time)
+
make_report(self, indicator=None)
+
+
+ async
+
+
+¶Build the strategy's backtest report.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+indicator | Union[pd.DataFrame, List[Tuple], None] | the indicator used during the backtest. If given, it is overlaid on the backtest chart; it should be a DataFrame indexed by date with the indicator column named "value" | None
omicron/strategy/base.py
async def make_report(
+ self, indicator: Union[pd.DataFrame, List[Tuple], None] = None
+):
+ """策略回测报告
+
+ Args:
+ indicator: 回测时使用的指标。如果存在,将叠加到策略回测图上。它应该是一个以日期为索引,指标列名为"value"的DataFrame
+ """
+ if self.bills is None or self.metrics is None:
+ raise ValueError("Please run `start_backtest` first.")
+
+ if isinstance(indicator, list):
+ assert len(indicator[0]) == 2
+ indicator = pd.DataFrame(indicator, columns=["date", "value"])
+ indicator.set_index("date", inplace=True)
+
+ mg = MetricsGraph(
+ self.bills,
+ self.metrics,
+ indicator=indicator,
+ baseline_code=self.bs.baseline,
+ )
+ await mg.plot()
+
peek(self, code, n)
+
+
+ async
+
+
+¶Let the strategy peek at future data.
+Intended for factor-testing scenarios; requires the data to be cached. Otherwise use Stock.get_bars or similar methods.
+omicron/strategy/base.py
async def peek(self, code: str, n: int):
+ """允许策略偷看未来数据
+
+ 可用以因子检验场景。要求数据本身已缓存。否则请用Stock.get_bars等方法获取。
+ """
+ if self.bs is None or self.bs.barss is None:
+ raise ValueError("data is not cached")
+
+ if code in self.bs.barss:
+ if self.bs.cursor + n + 1 < len(self.bs.barss[code]):
+ return Stock.qfq(
+ self.bs.barss[code][self.bs.cursor : self.bs.cursor + n]
+ )
+
+ else:
+ raise ValueError("data is not cached")
+
plot_metrics(self, indicator=None)
+
+
+ async
+
+
+¶Deprecated since 2.0.0: use `make_report` instead.
omicron/strategy/base.py
@deprecated("2.0.0", details="use `make_report` instead")
+async def plot_metrics(
+ self, indicator: Union[pd.DataFrame, List[Tuple], None] = None
+):
+ return await self.make_report(indicator)
+
positions(self, dt=None)
+
+
+¶Return the current positions.
+omicron/strategy/base.py
def positions(self, dt: Optional[datetime.date] = None):
+ """返回当前持仓"""
+ return self.broker.positions(dt)
+
predict(self, frame, frame_type, i, barss=None, **kwargs)
+
+
+ async
+
+
+¶Strategy evaluation function. Implement trading-signal detection and handling here.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+frame | Frame | the current frame | required
+frame_type | FrameType | the main frame type of the data | required
+i | int | the number of frames elapsed since the backtest start | required
+barss | Optional[Dict[str, BarsArray]] | if `prefetch_stocks` was passed to `backtest`, the bars over [start - warmup_period * frame_type, end] are prefetched, and on each call the slice over [start - warmup_period * frame_type, start + i * frame_type] is passed in, already forward-adjusted | None
+
+Keyword Args: the kwargs passed to `backtest` are forwarded to this method.
omicron/strategy/base.py
async def predict(
+ self,
+ frame: Frame,
+ frame_type: FrameType,
+ i: int,
+ barss: Optional[Dict[str, BarsArray]] = None,
+ **kwargs,
+):
+ """策略评估函数。在此函数中实现交易信号检测和处理。
+
+ Args:
+ frame: 当前时间帧
+ frame_type: 处理的数据主周期
+ i: 当前时间离回测起始的单位数
+        barss: 如果调用`backtest`时传入了`prefetch_stocks`参数,则`backtest`将会在回测之前,预取从[start - warmup_period * frame_type, end]间的行情数据,并在每次调用`predict`方法时,通过`barss`参数,将[start - warmup_period * frame_type, start + i * frame_type]间的数据传给`predict`方法。传入的数据已进行前复权。
+
+ Keyword Args: 在`backtest`方法中的传入的kwargs参数将被透传到此方法中。
+ """
+ raise NotImplementedError
+
sell(self, sec, price=None, vol=None, percent=None, order_time=None)
+
+
+ async
+
+
+¶Sell a security.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+sec | str | security code | required
+price | Optional[float] | ask price; if omitted, a market-price order is placed | None
+vol | Optional[float] | shares to sell; if None, percent must be given | None
+percent | Optional[float] | fraction of the position to sell, between 0 and 1; ignored when vol is also given. You must ensure the converted share count is valid (it may not be a multiple of 100, which is sometimes allowed, so it cannot be checked here) | None
+order_time | Optional[datetime.datetime] | required only in backtest mode; ignored in live mode | None
+
+Returns:
+
+Type | Description
+---|---
+Union[List, Dict] | the fill result; see the `buy` method in traderclient. The trade server returns a single order record
omicron/strategy/base.py
async def sell(
+ self,
+ sec: str,
+ price: Optional[float] = None,
+ vol: Optional[float] = None,
+ percent: Optional[float] = None,
+ order_time: Optional[datetime.datetime] = None,
+) -> Union[List, Dict]:
+ """卖出股票
+
+ Args:
+ sec: 证券代码
+ price: 委卖价,如果未提供,则转为市价单
+ vol: 委卖股数。如果为None,则percent必须传入
+ percent: 卖出一定比例的持仓,取值介于0与1之间。如果与vol同时提供,此参数将被忽略。请自行保证按比例换算后的卖出数据是符合要求的(比如不为100的倍数,但有些情况下这是允许的,所以程序这里无法帮你判断)
+ order_time: 仅在回测模式下需要提供。实盘模式下,此参数自动被忽略
+
+ Returns:
+ Union[List, Dict]: 成交返回,详见traderclient中的`buy`方法,trade server只返回一个委托单信息
+ """
+ logger.debug(
+ "sell order: %s, %s, %s, %s",
+ sec,
+ f"{price:.2f}" if price is not None else None,
+ f"{vol:.0f}" if vol is not None else None,
+ f"{percent:.2%}" if percent is not None else None,
+ date=order_time,
+ )
+
+ if vol is None and percent is None:
+ raise ValueError("either vol or percent must be presented")
+
+ if vol is None:
+ if price is None:
+ price = await self.broker._get_market_sell_price(
+ sec, order_time=order_time
+ )
+ # there's no market_sell_percent API in traderclient
+ return self.broker.sell_percent(sec, price, percent, order_time=order_time) # type: ignore
+ else:
+ if price is None:
+ return self.broker.market_sell(sec, vol, order_time=order_time)
+ else:
+ return self.broker.sell(sec, price, vol, order_time=order_time)
+
sma
+
+
+
+¶
+SMAStrategy (BaseStrategy)
+
+
+
+
+¶omicron/strategy/sma.py
class SMAStrategy(BaseStrategy):
+ def __init__(self, sec: str, n_short: int = 5, n_long: int = 10, *args, **kwargs):
+ self._sec = sec
+ self._n_short = n_short
+ self._n_long = n_long
+
+ self.indicators = []
+
+ super().__init__(*args, **kwargs)
+
+ async def before_start(self):
+ date = self.bs.end if self.bs is not None else None
+ logger.info("before_start, cash is %s", self.cash, date=date)
+
+    async def before_trade(self, date: datetime.date, barss=None):
+ logger.info(
+ "before_trade, cash is %s, portfolio is %s",
+ self.cash,
+ self.positions(date),
+ date=date,
+ )
+
+    async def after_trade(self, date: datetime.date, barss=None):
+ logger.info(
+ "after_trade, cash is %s, portfolio is %s",
+ self.cash,
+ self.positions(date),
+ date=date,
+ )
+
+ async def after_stop(self):
+ date = self.bs.end if self.bs is not None else None
+ logger.info(
+ "after_stop, cash is %s, portfolio is %s",
+ self.cash,
+ self.positions,
+ date=date,
+ )
+
+ async def predict(
+ self, frame: Frame, frame_type: FrameType, i: int, barss, **kwargs
+ ):
+ if barss is None:
+ raise ValueError("please specify `prefetch_stocks`")
+
+ bars: Union[BarsArray, None] = barss.get(self._sec)
+ if bars is None:
+ raise ValueError(f"{self._sec} not found in `prefetch_stocks`")
+
+ ma_short = np.mean(bars["close"][-self._n_short :])
+ ma_long = np.mean(bars["close"][-self._n_long :])
+
+ if ma_short > ma_long:
+ self.indicators.append((frame, 1))
+ if self.cash >= 100 * bars["close"][-1]:
+ await self.buy(
+ self._sec,
+ money=self.cash,
+ order_time=tf.combine_time(frame, 14, 55),
+ )
+ elif ma_short < ma_long:
+ self.indicators.append((frame, -1))
+ if self.available_shares(self._sec, frame) > 0:
+ await self.sell(
+ self._sec, percent=1.0, order_time=tf.combine_time(frame, 14, 55)
+ )
+
after_trade(self, date, barss=None)
+
+
+ async
+
+
+¶Wrap-up work after each trading day's close.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+date | date | the date: the backtest day in backtest mode, the system date in live mode | required
+barss | Optional[Dict[str, BarsArray]] | the prefetched bars, passed in when the main frame type is daily and prefetching is enabled | None
omicron/strategy/sma.py
async def after_trade(self, date: datetime.date, barss=None):
+ logger.info(
+ "after_trade, cash is %s, portfolio is %s",
+ self.cash,
+ self.positions(date),
+ date=date,
+ )
+
before_start(self)
+
+
+ async
+
+
+¶Preparation before the strategy starts.
+During a backtest, this is called inside `backtest` before entering the loop. If the strategy needs to derive adaptive parameters from historical data, implement that here.
+omicron/strategy/sma.py
async def before_start(self):
+ date = self.bs.end if self.bs is not None else None
+ logger.info("before_start, cash is %s", self.cash, date=date)
+
before_trade(self, date, barss=None)
+
+
+ async
+
+
+¶Preparation before each trading day's open.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+date | date | the date: the backtest day in backtest mode, the system date in live mode | required
+barss | Optional[Dict[str, BarsArray]] | the prefetched bars, passed in when the main frame type is daily and prefetching is enabled | None
omicron/strategy/sma.py
async def before_trade(self, date: datetime.date, barss=None):
+ logger.info(
+ "before_trade, cash is %s, portfolio is %s",
+ self.cash,
+ self.positions(date),
+ date=date,
+ )
+
predict(self, frame, frame_type, i, barss, **kwargs)
+
+
+ async
+
+
+¶Strategy evaluation function. Implement trading-signal detection and handling here.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+frame | Frame | the current frame | required
+frame_type | FrameType | the main frame type of the data | required
+i | int | the number of frames elapsed since the backtest start | required
+barss | Dict[str, BarsArray] | the prefetched bars passed in by `backtest`; see BaseStrategy.predict | required
+
+Keyword Args: the kwargs passed to `backtest` are forwarded to this method.
omicron/strategy/sma.py
async def predict(
+ self, frame: Frame, frame_type: FrameType, i: int, barss, **kwargs
+):
+ if barss is None:
+ raise ValueError("please specify `prefetch_stocks`")
+
+ bars: Union[BarsArray, None] = barss.get(self._sec)
+ if bars is None:
+ raise ValueError(f"{self._sec} not found in `prefetch_stocks`")
+
+ ma_short = np.mean(bars["close"][-self._n_short :])
+ ma_long = np.mean(bars["close"][-self._n_long :])
+
+ if ma_short > ma_long:
+ self.indicators.append((frame, 1))
+ if self.cash >= 100 * bars["close"][-1]:
+ await self.buy(
+ self._sec,
+ money=self.cash,
+ order_time=tf.combine_time(frame, 14, 55),
+ )
+ elif ma_short < ma_long:
+ self.indicators.append((frame, -1))
+ if self.available_shares(self._sec, frame) > 0:
+ await self.sell(
+ self._sec, percent=1.0, order_time=tf.combine_time(frame, 14, 55)
+ )
+
core
+
+
+
+¶angle(ts, threshold=0.01, loss_func='re')
+
+
+¶Return the cosine of the angle between the line fitted to time series `ts` and the `x` axis.
+This function can be used to judge the growth trend of a time series. When `angle` is within [-1, 0], the closer to 0, the steeper the decline; when within [0, 1], the closer to 0, the steeper the rise.
+If `ts` cannot be fitted well by a line, returns (float, None).
Examples:
+>>> ts = np.array([ i for i in range(5)])
+>>> round(angle(ts)[1], 3) # degree: 45, rad: pi/4
+0.707
+
>>> ts = np.array([ np.sqrt(3) / 3 * i for i in range(10)])
+>>> round(angle(ts)[1],3) # degree: 30, rad: pi/6
+0.866
+
>>> ts = np.array([ -np.sqrt(3) / 3 * i for i in range(7)])
+>>> round(angle(ts)[1], 3) # degree: 150, rad: 5*pi/6
+-0.866
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+ts | Sequence | the time series to fit | required
+threshold | float | maximum fitting error for the line fit to be accepted | 0.01
+loss_func | str | the error metric passed to polyfit | 're'
+
+Returns:
+
+Type | Description
+---|---
+Tuple[float, float] | (error, cosine(theta)): the fitting error and the cosine of the angle
omicron/talib/core.py
def angle(ts, threshold=0.01, loss_func="re") -> Tuple[float, float]:
+ """求时间序列`ts`拟合直线相对于`x`轴的夹角的余弦值
+
+ 本函数可以用来判断时间序列的增长趋势。当`angle`处于[-1, 0]时,越靠近0,下降越快;当`angle`
+ 处于[0, 1]时,越接近0,上升越快。
+
+ 如果`ts`无法很好地拟合为直线,则返回[float, None]
+
+ Examples:
+
+ >>> ts = np.array([ i for i in range(5)])
+    >>> round(angle(ts)[1], 3) # degree: 45, rad: pi/4
+ 0.707
+
+ >>> ts = np.array([ np.sqrt(3) / 3 * i for i in range(10)])
+ >>> round(angle(ts)[1],3) # degree: 30, rad: pi/6
+ 0.866
+
+ >>> ts = np.array([ -np.sqrt(3) / 3 * i for i in range(7)])
+ >>> round(angle(ts)[1], 3) # degree: 150, rad: 5*pi/6
+ -0.866
+
+ Args:
+ ts:
+
+ Returns:
+        返回 (error, cosine(theta)),即拟合误差和夹角余弦值。
+
+ """
+ err, (a, b) = polyfit(ts, deg=1, loss_func=loss_func)
+ if err > threshold:
+ return (err, None)
+
+ v = np.array([1, a + b])
+ vx = np.array([1, 0])
+
+ return err, copysign(np.dot(v, vx) / (norm(v) * norm(vx)), a)
+
clustering(numbers, n)
+
+
+¶Partition the array `numbers` into `n` clusters.
+Returns a List whose elements are tuples of (cluster start position, cluster length).
+
+Examples:
+>>> numbers = np.array([1,1,1,2,4,6,8,7,4,5,6])
+>>> clustering(numbers, 2)
+[(0, 4), (4, 7)]
+
+Returns:
+
+Type | Description
+---|---
+List[Tuple[int, int]] | the list of clusters
omicron/talib/core.py
def clustering(numbers: np.ndarray, n: int) -> List[Tuple[int, int]]:
+ """将数组`numbers`划分为`n`个簇
+
+    返回值为一个List, 每一个元素为一个Tuple,分别为簇的起始点和长度。
+
+ Examples:
+ >>> numbers = np.array([1,1,1,2,4,6,8,7,4,5,6])
+ >>> clustering(numbers, 2)
+ [(0, 4), (4, 7)]
+
+ Returns:
+ 划分后的簇列表。
+ """
+ result = ckwrap.cksegs(numbers, n)
+
+ clusters = []
+ for pos, size in zip(result.centers, result.sizes):
+ clusters.append((int(pos - size // 2 - 1), int(size)))
+
+ return clusters
+
exp_moving_average(values, window)
+
+
+¶Numpy implementation of EMA
+omicron/talib/core.py
def exp_moving_average(values, window):
+ """Numpy implementation of EMA"""
+ weights = np.exp(np.linspace(-1.0, 0.0, window))
+ weights /= weights.sum()
+ a = np.convolve(values, weights, mode="full")[: len(values)]
+ a[:window] = a[window]
+
+ return a
+
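+A quick sketch (assuming the function is importable from omicron.talib; it lives in omicron/talib/core.py):
+
+```python
+import numpy as np
+
+from omicron.talib import exp_moving_average
+
+values = np.arange(10, dtype=float)
+ema = exp_moving_average(values, window=5)
+print(np.round(ema[-1], 3))  # weighted toward the most recent values
+```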
mean_absolute_error(y, y_hat)
+
+
+¶Return the mean absolute error of a predicted series against the ground-truth series.
+The two series must have the same length. NaN values are excluded from the mean.
+
+Examples:
+>>> y = np.arange(5)
+>>> y_hat = np.arange(5)
+>>> y_hat[4] = 0
+>>> mean_absolute_error(y, y)
+0.0
+
>>> mean_absolute_error(y, y_hat)
+0.8
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+y | np.array | the ground-truth series | required
+y_hat | np.array | the series to compare | required
+
+Returns:
+
+Type | Description
+---|---
+float | the mean absolute error
omicron/talib/core.py
def mean_absolute_error(y: np.array, y_hat: np.array) -> float:
+ """返回预测序列相对于真值序列的平均绝对值差
+
+ 两个序列应该具有相同的长度。如果存在nan,则nan的值不计入平均值。
+
+ Examples:
+
+ >>> y = np.arange(5)
+ >>> y_hat = np.arange(5)
+ >>> y_hat[4] = 0
+ >>> mean_absolute_error(y, y)
+ 0.0
+
+ >>> mean_absolute_error(y, y_hat)
+ 0.8
+
+ Args:
+ y (np.array): 真值序列
+ y_hat: 比较序列
+
+ Returns:
+ float: 平均绝对值差
+ """
+ return nanmean(np.abs(y - y_hat))
+
moving_average(ts, win, padding=True)
+
+
+¶Compute the moving average of the series ts.
+
+Examples:
+>>> ts = np.arange(7)
+>>> moving_average(ts, 5)
+array([nan, nan, nan, nan, 2., 3., 4.])
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+ts | Sequence | the input array | required
+win | int | the window size | required
+padding | bool | if True, the result has the same length as the input, padded with np.NaN at the beginning | True
+
+Returns:
+
+Type | Description
+---|---
+ndarray | the moving mean of the input array; with padding, the output has the same shape as the input
omicron/talib/core.py
def moving_average(ts: Sequence, win: int, padding=True) -> np.ndarray:
+ """生成ts序列的移动平均值
+
+ Examples:
+
+ >>> ts = np.arange(7)
+ >>> moving_average(ts, 5)
+ array([nan, nan, nan, nan, 2., 3., 4.])
+
+ Args:
+ ts (Sequence): the input array
+ win (int): the window size
+ padding: if True, then the return will be equal length as input, padding with np.NaN at the beginning
+
+ Returns:
+ The moving mean of the input array along the specified axis. The output has the same shape as the input.
+ """
+ ma = move_mean(ts, win)
+ if padding:
+ return ma
+ else:
+ return ma[win - 1 :]
+
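+The padding flag only affects the output length, as this sketch shows:
+
+```python
+import numpy as np
+
+from omicron.talib import moving_average
+
+ts = np.arange(7)
+print(moving_average(ts, 5))                 # length 7, NaN-padded at the front
+print(moving_average(ts, 5, padding=False))  # length 3: [2., 3., 4.]
+```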
normalize(X, scaler='maxabs')
+
+
+¶Normalize the data.
+If scaler is maxabs, each element of X is scaled into [-1, 1]; if unit_vector, X is scaled to unit norm; if minmax, into [0, 1]; if standard, to zero mean and unit variance.
+See sklearn
+
+Examples:
+>>> X = [[ 1., -1., 2.],
+... [ 2., 0., 0.],
+... [ 0., 1., -1.]]
+
>>> expected = [[ 0.4082, -0.4082, 0.8165],
+... [ 1., 0., 0.],
+... [ 0., 0.7071, -0.7071]]
+
>>> X_hat = normalize(X, scaler='unit_vector')
+>>> np.testing.assert_array_almost_equal(expected, X_hat, decimal=4)
+
>>> expected = [[0.5, -1., 1.],
+... [1., 0., 0.],
+... [0., 1., -0.5]]
+
>>> X_hat = normalize(X, scaler='maxabs')
+>>> np.testing.assert_array_almost_equal(expected, X_hat, decimal = 2)
+
>>> expected = [[0.5 , 0. , 1. ],
+... [1. , 0.5 , 0.33333333],
+... [0. , 1. , 0. ]]
+>>> X_hat = normalize(X, scaler='minmax')
+>>> np.testing.assert_array_almost_equal(expected, X_hat, decimal= 3)
+
>>> X = [[0, 0],
+... [0, 0],
+... [1, 1],
+... [1, 1]]
+>>> expected = [[-1., -1.],
+... [-1., -1.],
+... [ 1., 1.],
+... [ 1., 1.]]
+>>> X_hat = normalize(X, scaler='standard')
+>>> np.testing.assert_array_almost_equal(expected, X_hat, decimal = 3)
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+X | 2D array | the data to normalize | required
+scaler | str | one of 'maxabs', 'unit_vector', 'minmax', 'standard' | 'maxabs'
omicron/talib/core.py
def normalize(X, scaler="maxabs"):
+ """对数据进行规范化处理。
+
+ 如果scaler为maxabs,则X的各元素被压缩到[-1,1]之间
+ 如果scaler为unit_vector,则将X的各元素压缩到单位范数
+ 如果scaler为minmax,则X的各元素被压缩到[0,1]之间
+ 如果scaler为standard,则X的各元素被压缩到单位方差之间,且均值为零。
+
+ 参考 [sklearn]
+
+ [sklearn]: https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#results
+
+ Examples:
+
+ >>> X = [[ 1., -1., 2.],
+ ... [ 2., 0., 0.],
+ ... [ 0., 1., -1.]]
+
+ >>> expected = [[ 0.4082, -0.4082, 0.8165],
+ ... [ 1., 0., 0.],
+ ... [ 0., 0.7071, -0.7071]]
+
+ >>> X_hat = normalize(X, scaler='unit_vector')
+ >>> np.testing.assert_array_almost_equal(expected, X_hat, decimal=4)
+
+ >>> expected = [[0.5, -1., 1.],
+ ... [1., 0., 0.],
+ ... [0., 1., -0.5]]
+
+ >>> X_hat = normalize(X, scaler='maxabs')
+ >>> np.testing.assert_array_almost_equal(expected, X_hat, decimal = 2)
+
+ >>> expected = [[0.5 , 0. , 1. ],
+ ... [1. , 0.5 , 0.33333333],
+ ... [0. , 1. , 0. ]]
+ >>> X_hat = normalize(X, scaler='minmax')
+ >>> np.testing.assert_array_almost_equal(expected, X_hat, decimal= 3)
+
+ >>> X = [[0, 0],
+ ... [0, 0],
+ ... [1, 1],
+ ... [1, 1]]
+ >>> expected = [[-1., -1.],
+ ... [-1., -1.],
+ ... [ 1., 1.],
+ ... [ 1., 1.]]
+ >>> X_hat = normalize(X, scaler='standard')
+ >>> np.testing.assert_array_almost_equal(expected, X_hat, decimal = 3)
+
+ Args:
+ X (2D array):
+ scaler (str, optional): [description]. Defaults to 'maxabs_scale'.
+ """
+ if scaler == "maxabs":
+ return MaxAbsScaler().fit_transform(X)
+ elif scaler == "unit_vector":
+ return sklearn.preprocessing.normalize(X, norm="l2")
+ elif scaler == "minmax":
+ return minmax_scale(X)
+ elif scaler == "standard":
+ return StandardScaler().fit_transform(X)
+
pct_error(y, y_hat)
+
+
+¶Error relative to the arithmetic mean of the series: the mean absolute error divided by the mean of |y|.
+
+Examples:
+>>> y = np.arange(5)
+>>> y_hat = np.arange(5)
+>>> y_hat[4] = 0
+>>> pct_error(y, y_hat)
+0.4
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+y | np.array | the ground-truth series | required
+y_hat | np.array | the series to compare | required
+
+Returns:
+
+Type | Description
+---|---
+float | the relative error
omicron/talib/core.py
def pct_error(y: np.array, y_hat: np.array) -> float:
+ """相对于序列算术均值的误差值
+
+ Examples:
+ >>> y = np.arange(5)
+ >>> y_hat = np.arange(5)
+ >>> y_hat[4] = 0
+ >>> pct_error(y, y_hat)
+ 0.4
+
+ Args:
+ y (np.array): [description]
+ y_hat (np.array): [description]
+
+ Returns:
+ float: [description]
+ """
+ mae = mean_absolute_error(y, y_hat)
+ return mae / nanmean(np.abs(y))
+
polyfit(ts, deg=2, loss_func='re')
+
+
+¶Fit the given time series with a line / quadratic curve.
+A quadratic curve can fit reversal patterns such as arc bottoms and arc tops, and also a one-sided leg of such trends. Series like long-term moving averages may move in a straight line over a period, so the function supports line fitting as well.
+To compare errors and coefficients across securities and time ranges, normalize ts beforehand. If the fit fails (an exception occurs), a very large error is returned and the remaining terms are set to np.nan.
+
+Examples:
+>>> ts = [i for i in range(5)]
+>>> err, (a, b) = polyfit(ts, deg=1)
+>>> print(round(err, 3), round(a, 1))
+0.0 1.0
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+ts | Sequence | the time series to fit | required
+deg | int | 1 for a line fit, 2 for a quadratic fit | 2
+loss_func | str | error metric: `mae`, `rmse`, `mse` or `re` (relative error) | 're'
+
+Returns:
+
+Type | Description
+---|---
+Tuple | for a line fit: error, (a, b) (slope and intercept). For a quadratic fit: error, (a, b, c) (quadratic, linear and constant terms), (vert_x, vert_y) (index and value of the vertex)
omicron/talib/core.py
def polyfit(ts: Sequence, deg: int = 2, loss_func="re") -> Tuple:
+ """对给定的时间序列进行直线/二次曲线拟合。
+
+    二次曲线可以拟合到发生反转的行情,如圆弧底、圆弧顶;也可以拟合到上述趋势中的单边走势,即其中一段曲线。对于如长期均线,在一段时间内走势可能呈现为一条直线,故也可用此函数进行直线拟合。
+
+ 为便于在不同品种、不同的时间之间对误差、系数进行比较,请事先对ts进行归一化。
+ 如果遇到无法拟合的情况(异常),将返回一个非常大的误差,并将其它项置为np.nan
+
+ Examples:
+ >>> ts = [i for i in range(5)]
+ >>> err, (a, b) = polyfit(ts, deg=1)
+ >>> print(round(err, 3), round(a, 1))
+ 0.0 1.0
+
+ Args:
+ ts (Sequence): 待拟合的时间序列
+ deg (int): 如果要进行直线拟合,取1;二次曲线拟合取2. Defaults to 2
+ loss_func (str): 误差计算方法,取值为`mae`, `rmse`,`mse` 或`re`。Defaults to `re` (relative_error)
+ Returns:
+ [Tuple]: 如果为直线拟合,返回误差,(a,b)(一次项系数和常数)。如果为二次曲线拟合,返回
+ 误差, (a,b,c)(二次项、一次项和常量), (vert_x, vert_y)(顶点处的index,顶点值)
+ """
+ if deg not in (1, 2):
+ raise ValueError("deg must be 1 or 2")
+
+ try:
+ if any(np.isnan(ts)):
+ raise ValueError("ts contains nan")
+
+ x = np.array(list(range(len(ts))))
+
+ z = np.polyfit(x, ts, deg=deg)
+
+ p = np.poly1d(z)
+ ts_hat = np.array([p(xi) for xi in x])
+
+ if loss_func == "mse":
+ error = np.mean(np.square(ts - ts_hat))
+ elif loss_func == "rmse":
+ error = np.sqrt(np.mean(np.square(ts - ts_hat)))
+ elif loss_func == "mae":
+ error = mean_absolute_error(ts, ts_hat)
+ else: # defaults to relative error
+ error = pct_error(ts, ts_hat)
+
+ if deg == 2:
+ a, b, c = z[0], z[1], z[2]
+ axis_x = -b / (2 * a)
+ if a != 0:
+ axis_y = (4 * a * c - b * b) / (4 * a)
+ else:
+ axis_y = None
+ return error, z, (axis_x, axis_y)
+ elif deg == 1:
+ return error, z
+ except Exception:
+ error = 1e9
+ if deg == 1:
+ return error, (np.nan, np.nan)
+ else:
+ return error, (np.nan, np.nan, np.nan), (np.nan, np.nan)
+
slope(ts, loss_func='re')
+
+
+¶Return the slope of the line represented by ts (if ts can be fitted by a line).
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+ts | np.array | the time series to fit | required
+loss_func | str | error metric passed to polyfit | 're'
omicron/talib/core.py
def slope(ts: np.array, loss_func="re"):
+ """求ts表示的直线(如果能拟合成直线的话)的斜率
+
+ Args:
+ ts (np.array): [description]
+ loss_func (str, optional): [description]. Defaults to 're'.
+ """
+ err, (a, b) = polyfit(ts, deg=1, loss_func=loss_func)
+
+ return err, a
+
smooth(ts, win, poly_order=1, mode='interp')
+
+
+¶Smooth the series ts with a window of size win; a linear model is used by default.
+This function exists mainly because omicron users may not be familiar with signal-processing concepts; it serves as an entry point to that functionality.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+ts | np.array | the series to smooth | required
+win | int | the window size | required
+poly_order | int | polynomial order of the smoothing model | 1
omicron/talib/core.py
def smooth(ts: np.array, win: int, poly_order=1, mode="interp"):
+ """平滑序列ts,使用窗口大小为win的平滑模型,默认使用线性模型
+
+ 提供本函数主要基于这样的考虑: omicron的使用者可能并不熟悉信号处理的概念,这里相当于提供了相关功能的一个入口。
+
+ Args:
+ ts (np.array): [description]
+ win (int): [description]
+ poly_order (int, optional): [description]. Defaults to 1.
+ """
+ return savgol_filter(ts, win, poly_order, mode=mode)
+
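+A quick sketch (assuming smooth is importable from omicron.talib; the underlying savgol_filter requires an odd window length):
+
+```python
+import numpy as np
+
+from omicron.talib import smooth
+
+noisy = np.sin(np.linspace(0, 3, 60)) + np.random.normal(0, 0.05, 60)
+smoothed = smooth(noisy, win=7)  # linear (poly_order=1) Savitzky-Golay smoothing
+```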
weighted_moving_average(ts, win)
+
+
+¶Compute the weighted moving average.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+ts | np.array | the input series | required
+win | int | the window size | required
+
+Returns:
+
+Type | Description
+---|---
+np.array | the weighted moving average
omicron/talib/core.py
def weighted_moving_average(ts: np.array, win: int) -> np.array:
+ """计算加权移动平均
+
+ Args:
+ ts (np.array): [description]
+ win (int): [description]
+
+ Returns:
+ np.array: [description]
+ """
+ w = [2 * (i + 1) / (win * (win + 1)) for i in range(win)]
+
+ return np.convolve(ts, w, "valid")
+
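+The weights are linear, w_i = 2(i+1)/(win*(win+1)), and sum to 1; for win=3 they are [1/6, 2/6, 3/6]. A worked sketch (note that np.convolve flips the kernel, so the largest weight lands on the earliest bar of each window):
+
+```python
+import numpy as np
+
+from omicron.talib import weighted_moving_average
+
+ts = np.array([0.0, 1.0, 2.0])
+# 0*(3/6) + 1*(2/6) + 2*(1/6) = 0.6667
+print(weighted_moving_average(ts, 3))
+```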
morph
+
+
+
+¶Morphological pattern detection utilities.
+
+BreakoutFlag (IntEnum)
+
+
+
+
+¶An enumeration.
+ +omicron/talib/morph.py
class BreakoutFlag(IntEnum):
+ UP = 1
+ DOWN = -1
+ NONE = 0
+
+CrossFlag (IntEnum)
+
+
+
+
+¶An enumeration.
+ +omicron/talib/morph.py
class CrossFlag(IntEnum):
+ UPCROSS = 1
+ DOWNCROSS = -1
+ NONE = 0
+
breakout(ts, upthres=0.01, downthres=-0.01, confirm=1)
+
+
+¶Detect whether the time series breaks out of its resistance or support (consolidation) line.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+ts | np.ndarray | the time series | required
+upthres | float | see peaks_and_valleys | 0.01
+downthres | float | see peaks_and_valleys | -0.01
+confirm | int | number of bars to wait before confirming the breakout | 1
+
+Returns:
+
+Type | Description
+---|---
+BreakoutFlag | 1 for an upward breakout of the resistance line, -1 for a downward breakout of the support line, 0 otherwise
omicron/talib/morph.py
def breakout(
+ ts: np.ndarray, upthres: float = 0.01, downthres: float = -0.01, confirm: int = 1
+) -> BreakoutFlag:
+ """检测时间序列是否突破了压力线(整理线)
+
+ Args:
+ ts (np.ndarray): 时间序列
+ upthres (float, optional): 请参考[peaks_and_valleys][omicron.talib.morph.peaks_and_valleys]
+ downthres (float, optional): 请参考[peaks_and_valleys][omicron.talib.morph.peaks_and_valleys]
+ confirm (int, optional): 经过多少个bars后,才确认突破。默认为1
+
+ Returns:
+        如果向上突破压力线,返回1,如果向下突破支撑线,返回-1,否则返回0
+ """
+ support, resist, _ = support_resist_lines(ts[:-confirm], upthres, downthres)
+
+ x0 = len(ts) - confirm - 1
+ x = list(range(len(ts) - confirm, len(ts)))
+
+ if resist is not None:
+ if np.all(ts[x] > resist(x)) and ts[x0] <= resist(x0):
+ return BreakoutFlag.UP
+
+ if support is not None:
+ if np.all(ts[x] < support(x)) and ts[x0] >= support(x0):
+ return BreakoutFlag.DOWN
+
+ return BreakoutFlag.NONE
+
cross(f, g)
+
+
+¶Determine whether series f crosses series g. If one or more crossings exist, the last one is used: 1 means f crosses above g, -1 means f crosses below g.
+This method can be used to check whether two moving averages cross.
+
+Returns:
+
+Type | Description
+---|---
+CrossFlag | (flag, index), where flag is 0 (no cross), -1 (f crosses below g) or 1 (f crosses above g)
omicron/talib/morph.py
def cross(f: np.ndarray, g: np.ndarray) -> CrossFlag:
+ """判断序列f是否与g相交。如果两个序列有且仅有一个交点,则返回1表明f上交g;-1表明f下交g
+
+ 本方法可用以判断两条均线是否相交。
+
+ returns:
+ (flag, index), 其中flag取值为:
+ 0 无效
+ -1 f向下交叉g
+ 1 f向上交叉g
+ """
+ indices = np.argwhere(np.diff(np.sign(f - g))).flatten()
+
+ if len(indices) == 0:
+ return CrossFlag.NONE, 0
+
+ # 如果存在一个或者多个交点,取最后一个
+ idx = indices[-1]
+
+ if f[idx] < g[idx]:
+ return CrossFlag.UPCROSS, idx
+ elif f[idx] > g[idx]:
+ return CrossFlag.DOWNCROSS, idx
+ else:
+ return CrossFlag(np.sign(g[idx - 1] - f[idx - 1])), idx
+
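+A sketch detecting a golden cross (assuming cross and CrossFlag are importable from omicron.talib):
+
+```python
+import numpy as np
+
+from omicron.talib import CrossFlag, cross
+
+fast = np.array([9.8, 9.9, 10.1, 10.4, 10.8])
+slow = np.array([10.2, 10.2, 10.2, 10.3, 10.3])
+flag, idx = cross(fast, slow)     # fast crosses above slow
+assert flag == CrossFlag.UPCROSS  # idx is the bar just before the cross
+```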
energy_hump(bars, thresh=2)
+
+
+¶Detect whether `bars` contains two or more waves of sharply rising volume (energy humps); returns the distance from the last hump to now and the span length.
+Note that if the last energy hump is too far from now (say more than 10 bars), capital may already have left and the energy may be exhausted.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+bars | BarsArray | bar data | required
+thresh | int | the last volume wave must exceed this multiple of the 20-day average volume | 2
+
+Returns:
+
+Type | Description
+---|---
+Optional[Tuple[int, int]] | None if no energy hump exists; otherwise the distance from the last hump to now and the span length
omicron/talib/morph.py
def energy_hump(bars: bars_dtype, thresh=2) -> Optional[Tuple[int, int]]:
+ """检测`bars`中是否存在两波以上量能剧烈增加的情形(能量驼峰),返回最后一波距现在的位置及区间长度。
+
+ 注意如果最后一个能量驼峰距现在过远(比如超过10个bar),可能意味着资金已经逃离,能量已经耗尽。
+
+ Args:
+ bars: 行情数据
+ thresh: 最后一波量必须大于20天均量的倍数。
+ Returns:
+ 如果不存在能量驼峰的情形,则返回None,否则返回最后一个驼峰离现在的距离及区间长度。
+ """
+ vol = bars["volume"]
+
+ std = np.std(vol[1:] / vol[:-1])
+ pvs = peak_valley_pivots(vol, std, 0)
+
+ frames = bars["frame"]
+
+ pvs[0] = 0
+ pvs[-1] = -1
+ peaks = np.argwhere(pvs == 1)
+
+ mn = np.mean(vol[peaks])
+
+ # 顶点不能缩量到尖峰均值以下
+ real_peaks = np.intersect1d(np.argwhere(vol > mn), peaks)
+
+ if len(real_peaks) < 2:
+ return None
+
+ logger.debug("found %s peaks at %s", len(real_peaks), frames[real_peaks])
+ lp = real_peaks[-1]
+ ma = moving_average(vol, 20)[lp]
+ if vol[lp] < ma * thresh:
+ logger.debug(
+ "vol of last peak[%s] is less than mean_vol(20) * thresh[%s]",
+ vol[lp],
+ ma * thresh,
+ )
+ return None
+
+ return len(bars) - real_peaks[-1], real_peaks[-1] - real_peaks[0]
+
inverse_vcross(f, g)
+
+
+¶Determine whether series f and g cross in a ^ shape: there are exactly two crossings, the first upward and the second downward. Useful for detecting topping patterns.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+f | np.array | the first series | required
+g | np.array | the second series | required
+
+Returns:
+
+Type | Description
+---|---
+Tuple | (matched, (idx0, idx1)): whether a ^ cross exists, and the two crossing indices
omicron/talib/morph.py
def inverse_vcross(f: np.array, g: np.array) -> Tuple:
+ """判断序列f是否与序列g存在^型相交。即存在两个交点,第一个交点为向上相交,第二个交点为向下
+ 相交。可用于判断见顶特征等场合。
+
+ Args:
+ f (np.array): [description]
+ g (np.array): [description]
+
+ Returns:
+ Tuple: [description]
+ """
+ indices = np.argwhere(np.diff(np.sign(f - g))).flatten()
+ if len(indices) == 2:
+ idx0, idx1 = indices
+ if f[idx0] < g[idx0] and f[idx1] > g[idx1]:
+ return True, (idx0, idx1)
+
+ return False, (None, None)
+
peaks_and_valleys(ts, up_thresh=None, down_thresh=None)
+
+
+¶Find the peaks and valleys in ts. Returns an array flagging each position: 1 for a peak, -1 for a valley.
+This function uses peak_valley_pivots from zigzag directly. Many approaches exist (e.g. scipy.signals.find_peaks_cwt, peak_valley_pivots); this one suits financial time series better and is cython-accelerated.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+ts | np.ndarray | the time series | required
+up_thresh | float | peak threshold; if None, twice the standard deviation of ts's change rate | None
+down_thresh | float | valley threshold; if None, minus twice the standard deviation of ts's change rate | None
+
+Returns:
+
+Type | Description
+---|---
+np.ndarray | an array flagging whether each position is a peak or a valley
omicron/talib/morph.py
def peaks_and_valleys(
+ ts: np.ndarray,
+ up_thresh: Optional[float] = None,
+ down_thresh: Optional[float] = None,
+) -> np.ndarray:
+ """寻找ts中的波峰和波谷,返回数组指示在该位置上是否为波峰或波谷。如果为1,则为波峰;如果为-1,则为波谷。
+
+ 本函数直接使用了zigzag中的peak_valley_pivots. 有很多方法可以实现本功能,比如scipy.signals.find_peaks_cwt, peak_valley_pivots等。本函数更适合金融时间序列,并且使用了cython加速。
+
+ Args:
+ ts (np.ndarray): 时间序列
+ up_thresh (float): 波峰的阈值,如果为None,则使用ts变化率的二倍标准差
+ down_thresh (float): 波谷的阈值,如果为None,则使用ts变化率的二倍标准差乘以-1
+
+ Returns:
+ np.ndarray: 返回数组指示在该位置上是否为波峰或波谷。
+ """
+ if ts.dtype != np.float64:
+ ts = ts.astype(np.float64)
+
+ if any([up_thresh is None, down_thresh is None]):
+ change_rate = ts[1:] / ts[:-1] - 1
+ std = np.std(change_rate)
+ up_thresh = up_thresh or 2 * std
+ down_thresh = down_thresh or -2 * std
+
+ return peak_valley_pivots(ts, up_thresh, down_thresh)
+
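+A quick sketch with explicit thresholds (ts must be float64; here a ±2% move marks a reversal):
+
+```python
+import numpy as np
+
+from omicron.talib import peaks_and_valleys
+
+close = np.array([10.0, 10.5, 11.2, 10.6, 10.0, 10.4, 11.1], dtype=np.float64)
+flags = peaks_and_valleys(close, 0.02, -0.02)
+print(np.argwhere(flags == 1).flatten())   # peak positions
+print(np.argwhere(flags == -1).flatten())  # valley positions
+```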
plateaus(numbers, min_size, fall_in_range_ratio=0.97)
+
+
+¶Find possible consolidation plateaus in the array `numbers`.
+A sub-array counts as a plateau when the fraction of its elements lying within three standard deviations of its mean exceeds `fall_in_range_ratio`.
+
+Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+numbers | ndarray | the input array | required
+min_size | int | minimum plateau length | required
+fall_in_range_ratio | float | a sub-array is a plateau when more than this fraction of its elements lie within three standard deviations of the mean | 0.97
+
+Returns:
+
+Type | Description
+---|---
+List[Tuple] | the (start, length) pairs of the plateaus
omicron/talib/morph.py
def plateaus(
+ numbers: np.ndarray, min_size: int, fall_in_range_ratio: float = 0.97
+) -> List[Tuple]:
+ """统计数组`numbers`中的可能存在的平台整理。
+
+ 如果一个数组中存在着子数组,使得其元素与均值的距离落在三个标准差以内的比例超过`fall_in_range_ratio`的,则认为该子数组满足平台整理。
+
+ Args:
+ numbers: 输入数组
+ min_size: 平台的最小长度
+ fall_in_range_ratio: 超过`fall_in_range_ratio`比例的元素落在均值的三个标准差以内,就认为该子数组构成一个平台
+
+ Returns:
+ 平台的起始位置和长度的数组
+ """
+ if numbers.size <= min_size:
+ n = 1
+ else:
+ n = numbers.size // min_size
+
+ clusters = clustering(numbers, n)
+
+ plats = []
+ for (start, length) in clusters:
+ if length < min_size:
+ continue
+
+ y = numbers[start : start + length]
+ mean = np.mean(y)
+ std = np.std(y)
+
+ inrange = len(y[np.abs(y - mean) < 3 * std])
+ ratio = inrange / length
+
+ if ratio >= fall_in_range_ratio:
+ plats.append((start, length))
+
+ return plats
+
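A hedged usage sketch on synthetic data. Note that plateaus relies on the module's internal clustering helper, so exactly which segments are reported depends on how the series gets segmented:

```python
import numpy as np

from omicron.talib.morph import plateaus

np.random.seed(1978)
# two sideways segments at different levels, joined by a jump
numbers = np.concatenate(
    [10 + np.random.randn(100) * 0.1, 20 + np.random.randn(100) * 0.1]
)

for start, length in plateaus(numbers, min_size=50):
    segment = numbers[start : start + length]
    print(f"plateau at {start}, length {length}, mean {segment.mean():.2f}")
```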
rsi_bottom_distance(close, thresh=None)
+
+
¶Given the closing prices, compute the distance from the last data point back to the previous point that signalled a low RSI level. If no low-level signal has been emitted between the previous RSI trough and the last data point, return the distance from the last data point back to that trough.

The length of close should generally be no less than 60. The return value is an integer distance; None is returned when the conditions are not met.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| close | np.array | time series of closing prices | required |
| thresh | Tuple[float, float] | None works for all stocks and normally need not be changed; it may also be set manually. | None |

Returns:

| Type | Description |
|---|---|
| int | the distance from the last data point to the previous low-RSI signal; if no low-level signal has been emitted since the previous RSI trough, the distance from the last data point to that trough. Otherwise None. |
+
omicron/talib/morph.py
def rsi_bottom_distance(close: np.array, thresh: Tuple[float, float] = None) -> int:
+ """根据给定的收盘价,计算最后一个数据到上一个发出rsi低水平的距离,
+ 如果从上一个最低点rsi到最后一个数据并未发出低水平信号,
+ 返回最后一个数据到上一个发出最低点rsi的距离。
+
+ 其中close的长度一般不小于60。
+ 返回值为距离整数,不满足条件则返回None。
+
+ Args:
+ close (np.array): 具有时间序列的收盘价
+ thresh (Tuple[float, float]) : None适用所有股票,不必更改,也可自行设置。
+
+ Returns:
+ 返回最后一个数据到上一个发出rsi低水平的距离。
+ 如果从上一个最低点rsi到最后一个数据并未发出低水平信号,
+ 返回最后一个数据到上一个发出最低点rsi的距离。
+ 除此之外,返回None。"""
+
+ assert len(close) >= 60, "must provide an array with at least 60 length!"
+
+ if close.dtype != np.float64:
+ close = close.astype(np.float64)
+
+ if thresh is None:
+ std = np.std(close[-59:] / close[-60:-1] - 1)
+ thresh = (2 * std, -2 * std)
+
+ rsi = ta.RSI(close, 6)
+
+ watermarks = rsi_watermarks(close, thresh)
+ if watermarks is not None:
+ low_watermark, _, _ = watermarks
+ pivots = peak_valley_pivots(close, thresh[0], thresh[1])
+ pivots[0], pivots[-1] = 0, 0
+
+ # 谷值RSI<30
+ valley_rsi_index = np.where((rsi < 30) & (pivots == -1))[0]
+
+ # RSI低水平的最大值:低水平*1.01
+ low_rsi_index = np.where(rsi <= low_watermark * 1.01)[0]
+
+ if len(valley_rsi_index) > 0:
+ distance = len(rsi) - 1 - valley_rsi_index[-1]
+ if len(low_rsi_index) > 0:
+ if low_rsi_index[-1] >= valley_rsi_index[-1]:
+ distance = len(rsi) - 1 - low_rsi_index[-1]
+ return distance
+
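A hedged sketch of calling it on synthetic data. This requires TA-Lib (for ta.RSI) and zigzag to be installed, and on arbitrary data the result may legitimately be None:

```python
import numpy as np

from omicron.talib.morph import rsi_bottom_distance

np.random.seed(1978)
decline = 100 * np.cumprod(1 - np.abs(np.random.randn(50)) * 0.01)
rebound = decline[-1] * np.cumprod(1 + np.abs(np.random.randn(30)) * 0.005)
close = np.concatenate([decline, rebound])

# an integer distance, or None when no low-RSI signal exists in this window
print(rsi_bottom_distance(close))
```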
rsi_bottom_divergent(close, thresh=None, rsi_limit=30)
+
+
¶Find the most recent RSI bottom divergence that meets the given conditions.

Returns the distance from the last data point to the most recent bottom divergence; None if no qualifying bottom divergence exists.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| close | np.array | time series of closing prices | required |
| thresh | Tuple[float, float] | see [peaks_and_valleys][omicron.talib.morph.peaks_and_valleys] | None |
| rsi_limit | float | RSI threshold for bottom divergence, default 30 (20 works better but yields too few detections); only local lowest closes with RSI6 < 30 are considered. | 30 |

Returns:

| Type | Description |
|---|---|
| int | the distance from the last data point to the most recent bottom divergence; None if no qualifying bottom divergence exists. |
+
omicron/talib/morph.py
def rsi_bottom_divergent(
+ close: np.array, thresh: Tuple[float, float] = None, rsi_limit: float = 30
+) -> int:
+ """寻找最近满足条件的rsi底背离。
+
+ 返回最后一个数据到最近底背离发生点的距离;没有满足条件的底背离,返回None。
+
+ Args:
+ close (np.array): 时间序列收盘价
+ thresh (Tuple[float, float]): 请参考[peaks_and_valleys][omicron.talib.morph.peaks_and_valleys]
+ rsi_limit (float, optional): RSI发生底背离时的阈值, 默认值30(20效果更佳,但是检测出来数量太少),即只过滤RSI6<30的局部最低收盘价。
+
+ Returns:
+ 返回int类型的整数,表示最后一个数据到最近底背离发生点的距离;没有满足条件的底背离,返回None。
+ """
+ assert len(close) >= 60, "must provide an array with at least 60 length!"
+ if close.dtype != np.float64:
+ close = close.astype(np.float64)
+ rsi = ta.RSI(close, 6)
+
+ if thresh is None:
+ std = np.std(close[-59:] / close[-60:-1] - 1)
+ thresh = (2 * std, -2 * std)
+
+ pivots = peak_valley_pivots(close, thresh[0], thresh[1])
+ pivots[0], pivots[-1] = 0, 0
+
+ length = len(close)
+ valley_index = np.where((pivots == -1) & (rsi <= rsi_limit))[0]
+
+ if len(valley_index) >= 2:
+ if (close[valley_index[-1]] < close[valley_index[-2]]) and (
+ rsi[valley_index[-1]] > rsi[valley_index[-2]]
+ ):
+ bottom_dev_distance = length - 1 - valley_index[-1]
+
+ return bottom_dev_distance
+
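In practice one typically treats only a small distance as a fresh signal. A hypothetical wrapper, sketched here for illustration (the 5-bar cutoff is an arbitrary choice, not a library recommendation):

```python
from typing import Optional

import numpy as np

from omicron.talib.morph import rsi_bottom_divergent


def fresh_bottom_divergence(close: np.ndarray, max_distance: int = 5) -> bool:
    """True when a bottom divergence fired within the last `max_distance` bars."""
    distance: Optional[int] = rsi_bottom_divergent(close)
    return distance is not None and distance <= max_distance
```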
rsi_predict_price(close, thresh=None)
+
+
¶Given a stretch of bars, predict the lowest and highest prices that may be reached in the next period, based on the two most recent RSI minima and maxima.

The idea is to take the two most recent highs and lows, compute their corresponding RSI values and average them (using the single most recent value if only one exists), then invert the RSI formula to recover prices. If only a high exists, the return value is (None, float), i.e. only a predicted high and no predicted low, and vice versa.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| close | np.ndarray | time series of closing prices | required |
| thresh | Tuple[float, float] | see [peaks_and_valleys][omicron.talib.morph.peaks_and_valleys] | None |

Returns:

| Type | Description |
|---|---|
| Tuple[float, float] | the pair [predicted_low_price, predicted_high_price]: the first value is the low predicted from the RSI at the previous lowest close, the second the high predicted from the RSI at the previous highest close. |
+
omicron/talib/morph.py
def rsi_predict_price(
+ close: np.ndarray, thresh: Tuple[float, float] = None
+) -> Tuple[float, float]:
+ """给定一段行情,根据最近的两个RSI的极小值和极大值预测下一个周期可能达到的最低价格和最高价格。
+
+ 其原理是,以预测最近的两个最高价和最低价,求出其相对应的RSI值,求出最高价和最低价RSI的均值,
+ 若只有一个则取最近的一个。再由RSI公式,反推价格。此时返回值为(None, float),即只有最高价,没有最低价。反之亦然。
+
+ Args:
+ close (np.ndarray): 具有时间序列的收盘价
+ thresh (Tuple[float, float]) : 请参考[peaks_and_valleys][omicron.talib.morph.peaks_and_valleys]
+
+ Returns:
+ 返回数组[predicted_low_price, predicted_high_price], 数组第一个值为利用达到之前最低收盘价的RSI预测的最低价。
+ 第二个值为利用达到之前最高收盘价的RSI预测的最高价。
+ """
+ assert len(close) >= 60, "must provide an array with at least 60 length!"
+
+ if thresh is None:
+ std = np.std(close[-59:] / close[-60:-1] - 1)
+ thresh = (2 * std, -2 * std)
+
+ if close.dtype != np.float64:
+ close = close.astype(np.float64)
+
+ valley_rsi, peak_rsi, _ = rsi_watermarks(close, thresh=thresh)
+ pivot = peak_valley_pivots(close, thresh[0], thresh[1])
+ pivot[0], pivot[-1] = 0, 0 # 掐头去尾
+
+ price_change = pd.Series(close).diff(1).values
+ ave_price_change = (abs(price_change)[-6:].mean()) * 5
+ ave_price_raise = (np.maximum(price_change, 0)[-6:].mean()) * 5
+
+ if valley_rsi is not None:
+ predicted_low_change = (ave_price_change) - ave_price_raise / (
+ 0.01 * valley_rsi
+ )
+ if predicted_low_change > 0:
+ predicted_low_change = 0
+ predicted_low_price = close[-1] + predicted_low_change
+ else:
+ predicted_low_price = None
+
+ if peak_rsi is not None:
+ predicted_high_change = (ave_price_raise - ave_price_change) / (
+ 0.01 * peak_rsi - 1
+ ) - ave_price_change
+ if predicted_high_change < 0:
+ predicted_high_change = 0
+ predicted_high_price = close[-1] + predicted_high_change
+ else:
+ predicted_high_price = None
+
+ return predicted_low_price, predicted_high_price
+
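A hedged usage sketch (TA-Lib and zigzag required; either side of the prediction may be None when the window lacks a qualifying peak or valley):

```python
import numpy as np

from omicron.talib.morph import rsi_predict_price

np.random.seed(1978)
close = 100 * np.cumprod(1 + np.random.randn(80) * 0.01)

low, high = rsi_predict_price(close)
# either side may be None when the window lacks a qualifying valley or peak
print(f"last close {close[-1]:.2f}, predicted range: {low} .. {high}")
```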
rsi_top_distance(close, thresh=None)
+
+
¶Given the closing prices, compute the distance from the last data point back to the previous point that signalled a high RSI level. If no high-level signal has been emitted between the previous RSI peak and the last data point, return the distance from the last data point back to that peak.

The length of close should generally be no less than 60. The return value is an integer distance; None is returned when the conditions are not met.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| close | np.array | time series of closing prices | required |
| thresh | Tuple[float, float] | None works for all stocks and normally need not be changed; it may also be set manually. | None |

Returns:

| Type | Description |
|---|---|
| int | the distance from the last data point to the previous high-RSI signal; if no high-level signal has been emitted since the previous RSI peak, the distance from the last data point to that peak. Otherwise None. |
+
omicron/talib/morph.py
def rsi_top_distance(close: np.array, thresh: Tuple[float, float] = None) -> int:
+ """根据给定的收盘价,计算最后一个数据到上一个发出rsi高水平的距离,
+ 如果从上一个最高点rsi到最后一个数据并未发出高水平信号,
+ 返回最后一个数据到上一个发出最高点rsi的距离。
+
+ 其中close的长度一般不小于60。
+ 返回值为距离整数,不满足条件则返回None。
+
+ Args:
+ close (np.array): 具有时间序列的收盘价
+ thresh (Tuple[float, float]) : None适用所有股票,不必更改,也可自行设置。
+
+ Returns:
+ 返回最后一个数据到上一个发出rsi高水平的距离。
+ 如果从上一个最高点rsi到最后一个数据并未发出高水平信号,
+ 返回最后一个数据到上一个发出最高点rsi的距离。
+ 除此之外,返回None。"""
+
+ assert len(close) >= 60, "must provide an array with at least 60 length!"
+
+ if close.dtype != np.float64:
+ close = close.astype(np.float64)
+
+ if thresh is None:
+ std = np.std(close[-59:] / close[-60:-1] - 1)
+ thresh = (2 * std, -2 * std)
+
+ rsi = ta.RSI(close, 6)
+
+ watermarks = rsi_watermarks(close, thresh)
+ if watermarks is not None:
+ _, high_watermark, _ = watermarks
+ pivots = peak_valley_pivots(close, thresh[0], thresh[1])
+ pivots[0], pivots[-1] = 0, 0
+
+ # 峰值RSI>70
+ peak_rsi_index = np.where((rsi > 70) & (pivots == 1))[0]
+
+ # RSI高水平的最小值:高水平*0.99
+ high_rsi_index = np.where(rsi >= high_watermark * 0.99)[0]
+
+ if len(peak_rsi_index) > 0:
+ distance = len(rsi) - 1 - peak_rsi_index[-1]
+ if len(high_rsi_index) > 0:
+ if high_rsi_index[-1] >= peak_rsi_index[-1]:
+ distance = len(rsi) - 1 - high_rsi_index[-1]
+ return distance
+
rsi_top_divergent(close, thresh=None, rsi_limit=70)
+
+
¶Find the most recent RSI top divergence that meets the given conditions.

Returns the distance from the last data point to the most recent top divergence; None if no qualifying top divergence exists.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| close | np.array | time series of closing prices | required |
| thresh | Tuple[float, float] | see [peaks_and_valleys][omicron.talib.morph.peaks_and_valleys] | None |
| rsi_limit | float | RSI threshold for top divergence, default 70 (80 works better but yields too few detections); only local highest closes with RSI6 > 70 are considered. | 70 |

Returns:

| Type | Description |
|---|---|
| int | the distance from the last data point to the most recent top divergence; None if no qualifying top divergence exists. |
+
omicron/talib/morph.py
def rsi_top_divergent(
+    close: np.array, thresh: Tuple[float, float] = None, rsi_limit: float = 70
+) -> int:
+ """寻找最近满足条件的rsi顶背离。
+
+ 返回最后一个数据到最近顶背离发生点的距离;没有满足条件的顶背离,返回None。
+
+ Args:
+ close (np.array): 时间序列收盘价
+ thresh (Tuple[float, float]): 请参考[peaks_and_valleys][omicron.talib.morph.peaks_and_valleys]
+ rsi_limit (float, optional): RSI发生顶背离时的阈值, 默认值70(80效果更佳,但是检测出来数量太少),即只过滤RSI6>70的局部最高收盘价。
+
+ Returns:
+ 返回int类型的整数,表示最后一个数据到最近顶背离发生点的距离;没有满足条件的顶背离,返回None。
+ """
+ assert len(close) >= 60, "must provide an array with at least 60 length!"
+ if close.dtype != np.float64:
+ close = close.astype(np.float64)
+ rsi = ta.RSI(close, 6)
+
+ if thresh is None:
+ std = np.std(close[-59:] / close[-60:-1] - 1)
+ thresh = (2 * std, -2 * std)
+
+ pivots = peak_valley_pivots(close, thresh[0], thresh[1])
+ pivots[0], pivots[-1] = 0, 0
+
+ length = len(close)
+ peak_index = np.where((pivots == 1) & (rsi >= rsi_limit))[0]
+
+ if len(peak_index) >= 2:
+ if (close[peak_index[-1]] > close[peak_index[-2]]) and (
+ rsi[peak_index[-1]] < rsi[peak_index[-2]]
+ ):
+ top_dev_distance = length - 1 - peak_index[-1]
+
+ return top_dev_distance
+
rsi_watermarks(close, thresh=None)
+
+
¶Given a stretch of bars and the thresholds used to detect tops and bottoms, return the mean RSI at the valleys, the mean RSI at the peaks, and the last RSI6 value.

The length of close should generally be between 60 and 120. Of the return values, one is low_watermark (RSI at the valleys), one is high_watermark (RSI at the peaks), and one is the last RSI6 value, to be compared against the two watermarks.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| close | np.array | time series of closing prices | required |
| thresh | Tuple[float, float] | None works for all stocks and normally need not be changed; it may also be set manually. | None |

Returns:

| Type | Description |
|---|---|
| Tuple[float, float, float] | the triple [low_watermark, high_watermark, rsi[-1]]: the first is the mean RSI of the two most recent lowest closes, the second the mean RSI of the two most recent highest closes. If the closes contain only one such extremum, only that one is used; with none, None is returned for it. The third is the actual last RSI6 value. |
+
omicron/talib/morph.py
def rsi_watermarks(
+ close: np.array, thresh: Tuple[float, float] = None
+) -> Tuple[float, float, float]:
+ """给定一段行情数据和用以检测顶和底的阈值,返回该段行情中,谷和峰处RSI均值,最后一个RSI6值。
+
+    其中close的长度一般不小于60,不大于120。返回值中,一个为low_watermark(谷底处RSI值),
+    一个为high_watermark(高峰处RSI值),一个为RSI6的最后一个值,用以对比前两个警戒值。
+
+ Args:
+ close (np.array): 具有时间序列的收盘价
+ thresh (Tuple[float, float]) : None适用所有股票,不必更改,也可自行设置。
+
+ Returns:
+ 返回数组[low_watermark, high_watermark, rsi[-1]], 第一个为最近两个最低收盘价的RSI均值, 第二个为最近两个最高收盘价的RSI均值。
+ 若传入收盘价只有一个最值,只返回一个。没有最值,则返回None, 第三个为实际的最后RSI6的值。
+ """
+ assert len(close) >= 60, "must provide an array with at least 60 length!"
+
+ if thresh is None:
+ std = np.std(close[-59:] / close[-60:-1] - 1)
+ thresh = (2 * std, -2 * std)
+
+ if close.dtype != np.float64:
+ close = close.astype(np.float64)
+ rsi = ta.RSI(close, 6)
+
+ pivots = peak_valley_pivots(close, thresh[0], thresh[1])
+ pivots[0], pivots[-1] = 0, 0 # 掐头去尾
+
+ # 峰值RSI>70; 谷处的RSI<30;
+ peaks_rsi_index = np.where((rsi > 70) & (pivots == 1))[0]
+ valleys_rsi_index = np.where((rsi < 30) & (pivots == -1))[0]
+
+ if len(peaks_rsi_index) == 0:
+ high_watermark = None
+ elif len(peaks_rsi_index) == 1:
+ high_watermark = rsi[peaks_rsi_index[0]]
+ else: # 有两个以上的峰,通过最近的两个峰均值来确定走势
+ high_watermark = np.nanmean(rsi[peaks_rsi_index[-2:]])
+
+ if len(valleys_rsi_index) == 0:
+ low_watermark = None
+ elif len(valleys_rsi_index) == 1:
+ low_watermark = rsi[valleys_rsi_index[0]]
+ else: # 有两个以上的峰,通过最近的两个峰来确定走势
+ low_watermark = np.nanmean(rsi[valleys_rsi_index[-2:]])
+
+ return low_watermark, high_watermark, rsi[-1]
+
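A hedged sketch of how the three return values are typically used together; the 1.01 tolerance mirrors what rsi_bottom_distance applies internally to the low watermark:

```python
import numpy as np

from omicron.talib.morph import rsi_watermarks

np.random.seed(1978)
close = 100 * np.cumprod(1 + np.random.randn(100) * 0.02)

low_wm, high_wm, last_rsi = rsi_watermarks(close)
# compare the current RSI6 against this window's historical watermarks
if low_wm is not None and last_rsi <= low_wm * 1.01:
    print("RSI6 is near the low watermark:", last_rsi, low_wm)
```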
support_resist_lines(ts, upthres=None, downthres=None)
+
+
¶Compute the support and resistance lines for a time series.

The two most recent highs are joined to form the resistance line, and the two most recent lows to form the support line.

Examples:

    def show_support_resist_lines(ts):
        import plotly.graph_objects as go

        fig = go.Figure()

        support, resist, x_start = support_resist_lines(ts, 0.03, -0.03)
        fig.add_trace(go.Scatter(x=np.arange(len(ts)), y=ts))

        x = np.arange(len(ts))[x_start:]
        fig.add_trace(go.Line(x=x, y = support(x)))
        fig.add_trace(go.Line(x=x, y = resist(x)))

        fig.show()

    np.random.seed(1978)
    X = np.cumprod(1 + np.random.randn(100) * 0.01)
    show_support_resist_lines(X)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ts | np.ndarray | the time series | required |
| upthres | float | see [peaks_and_valleys][omicron.talib.morph.peaks_and_valleys] | None |
| downthres | float | see [peaks_and_valleys][omicron.talib.morph.peaks_and_valleys] | None |

Returns:

| Type | Description |
|---|---|
| Tuple[Callable, Callable, numpy.ndarray] | the functions computing the support and resistance lines and the starting x coordinate; None is returned for a missing support or resistance line. |
+
omicron/talib/morph.py
def support_resist_lines(
+ ts: np.ndarray, upthres: float = None, downthres: float = None
+) -> Tuple[Callable, Callable, np.ndarray]:
+ """计算时间序列的支撑线和阻力线
+
+    使用最近的两个高点连接成阻力线,两个低点连接成支撑线。
+
+ Examples:
+ ```python
+ def show_support_resist_lines(ts):
+ import plotly.graph_objects as go
+
+ fig = go.Figure()
+
+ support, resist, x_start = support_resist_lines(ts, 0.03, -0.03)
+ fig.add_trace(go.Scatter(x=np.arange(len(ts)), y=ts))
+
+ x = np.arange(len(ts))[x_start:]
+ fig.add_trace(go.Line(x=x, y = support(x)))
+ fig.add_trace(go.Line(x=x, y = resist(x)))
+
+ fig.show()
+
+ np.random.seed(1978)
+ X = np.cumprod(1 + np.random.randn(100) * 0.01)
+ show_support_resist_lines(X)
+ ```
+ the above code will show this ![](https://images.jieyu.ai/images/202204/support_resist.png)
+
+ Args:
+ ts (np.ndarray): 时间序列
+ upthres (float, optional): 请参考[peaks_and_valleys][omicron.talib.morph.peaks_and_valleys]
+ downthres (float, optional): 请参考[peaks_and_valleys][omicron.talib.morph.peaks_and_valleys]
+
+ Returns:
+ 返回支撑线和阻力线的计算函数及起始点坐标,如果没有支撑线或阻力线,则返回None
+ """
+ if ts.dtype != np.float64:
+ ts = ts.astype(np.float64)
+
+ pivots = peaks_and_valleys(ts, upthres, downthres)
+ pivots[0] = 0
+ pivots[-1] = 0
+
+ arg_max = np.argwhere(pivots == 1).flatten()
+ arg_min = np.argwhere(pivots == -1).flatten()
+
+ resist = None
+ support = None
+
+ if len(arg_max) >= 2:
+ arg_max = arg_max[-2:]
+ y = ts[arg_max]
+ coeff = np.polyfit(arg_max, y, deg=1)
+
+ resist = np.poly1d(coeff)
+
+ if len(arg_min) >= 2:
+ arg_min = arg_min[-2:]
+ y = ts[arg_min]
+ coeff = np.polyfit(arg_min, y, deg=1)
+
+ support = np.poly1d(coeff)
+
+ return support, resist, np.min([*arg_min, *arg_max])
+
valley_detect(close, thresh=(0.05, -0.02))
+
+
¶Given a stretch of bars, detect the most recent reversal low; return the distance from that low to the last data point, and the gain since then. If no low satisfying the parameters is found, both return values are None.

The length of bars should generally be between 60 and 120. This function uses the peak/valley detection from zigzag; the default thresholds (0.05, -0.02) work for all stocks. When the conditions are met, the returned distance is a positive integer and the gain is a decimal between 0 and 1.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| close | np.ndarray | time series of closing prices | required |
| thresh | Tuple[float, float] | see [peaks_and_valleys][omicron.talib.morph.peaks_and_valleys] | (0.05, -0.02) |

Returns:

| Type | Description |
|---|---|
| Tuple[int, float] | the distance from the most recent qualifying low to the last data point, and the gain since then; (None, None) if no qualifying low is found. |
+
omicron/talib/morph.py
def valley_detect(
+    close: np.ndarray, thresh: Tuple[float, float] = (0.05, -0.02)
+) -> Tuple[int, float]:
+ """给定一段行情数据和用以检测近期已发生反转的最低点,返回该段行情中,最低点到最后一个数据的距离和收益率数组,
+ 如果给定行情中未找到满足参数的最低点,则返回两个空值数组。
+
+ 其中bars的长度一般不小于60,不大于120。此函数采用了zigzag中的谷峰检测方法,其中参数默认(0.05,-0.02),
+ 此参数对所有股票数据都适用。若满足参数,返回值中,距离为大于0的整数,收益率是0~1的小数。
+
+ Args:
+ close (np.ndarray): 具有时间序列的收盘价
+ thresh (Tuple[float, float]) : 请参考[peaks_and_valleys][omicron.talib.morph.peaks_and_valleys]
+
+ Returns:
+ 返回该段行情中,最低点到最后一个数据的距离和收益率数组,
+ 如果给定行情中未找到满足参数的最低点,则返回两个空值数组。
+ """
+
+ assert len(close) >= 60, "must provide an array with at least 60 length!"
+
+ if close.dtype != np.float64:
+ close = close.astype(np.float64)
+
+ if thresh is None:
+ std = np.std(close[-59:] / close[-60:-1] - 1)
+ thresh = (2 * std, -2 * std)
+
+ pivots = peak_valley_pivots(close, thresh[0], thresh[1])
+ flags = pivots[pivots != 0]
+ increased = None
+ lowest_distance = None
+ if (flags[-2] == -1) and (flags[-1] == 1):
+ length = len(pivots)
+ valley_index = np.where(pivots == -1)[0]
+ increased = (close[-1] - close[valley_index[-1]]) / close[valley_index[-1]]
+ lowest_distance = int(length - 1 - valley_index[-1])
+
+ return lowest_distance, increased
+
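A hedged sketch on synthetic data (zigzag required). Whether a reversal is reported depends on the pivots found, so the None case is handled explicitly:

```python
import numpy as np

from omicron.talib.morph import valley_detect

np.random.seed(1978)
decline = 100 * np.cumprod(1 - np.abs(np.random.randn(70)) * 0.01)
rebound = decline[-1] * np.cumprod(1 + np.abs(np.random.randn(10)) * 0.01)
close = np.concatenate([decline, rebound])

distance, gain = valley_detect(close)
if distance is not None:
    print(f"reversal low {distance} bars ago, up {gain:.1%} since")
```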
vcross(f, g)
+
+
¶Determine whether sequence f makes a V-shaped cross with sequence g, i.e. there are two intersections, the first crossing downward and the second crossing upward. Typically a sign of a shakeout followed by a rally.

Examples:
+>>> f = np.array([ 3 * i ** 2 - 20 * i + 2 for i in range(10)])
+>>> g = np.array([ i - 5 for i in range(10)])
+>>> flag, indices = vcross(f, g)
+>>> assert flag is True
+>>> assert indices[0] == 0
+>>> assert indices[1] == 6
+
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| f | np.array | first sequence | required |
| g | np.array | the second sequence | required |

Returns:

| Type | Description |
|---|---|
| Tuple | (flag, indices): when flag is True, a V-cross exists and indices holds the indices of the crossing points. |
+
omicron/talib/morph.py
def vcross(f: np.array, g: np.array) -> Tuple:
+ """判断序列f是否与g存在类型v型的相交。即存在两个交点,第一个交点为向下相交,第二个交点为向上
+ 相交。一般反映为洗盘拉升的特征。
+
+ Examples:
+
+ >>> f = np.array([ 3 * i ** 2 - 20 * i + 2 for i in range(10)])
+ >>> g = np.array([ i - 5 for i in range(10)])
+ >>> flag, indices = vcross(f, g)
+ >>> assert flag is True
+ >>> assert indices[0] == 0
+ >>> assert indices[1] == 6
+
+ Args:
+ f: first sequence
+ g: the second sequence
+
+ Returns:
+ (flag, indices), 其中flag取值为True时,存在vcross,indices为交点的索引。
+ """
+ indices = np.argwhere(np.diff(np.sign(f - g))).flatten()
+ if len(indices) == 2:
+ idx0, idx1 = indices
+ if f[idx0] > g[idx0] and f[idx1] < g[idx1]:
+ return True, (idx0, idx1)
+
+ return False, (None, None)
+
+TimeFrame
+
+
+
+¶omicron/models/timeframe.py
class TimeFrame:
+ minute_level_frames = [
+ FrameType.MIN1,
+ FrameType.MIN5,
+ FrameType.MIN15,
+ FrameType.MIN30,
+ FrameType.MIN60,
+ ]
+ day_level_frames = [
+ FrameType.DAY,
+ FrameType.WEEK,
+ FrameType.MONTH,
+ FrameType.QUARTER,
+ FrameType.YEAR,
+ ]
+
+ ticks = {
+ FrameType.MIN1: [i for i in itertools.chain(range(571, 691), range(781, 901))],
+ FrameType.MIN5: [
+ i for i in itertools.chain(range(575, 695, 5), range(785, 905, 5))
+ ],
+ FrameType.MIN15: [
+ i for i in itertools.chain(range(585, 705, 15), range(795, 915, 15))
+ ],
+ FrameType.MIN30: [
+ int(s[:2]) * 60 + int(s[2:])
+ for s in ["1000", "1030", "1100", "1130", "1330", "1400", "1430", "1500"]
+ ],
+ FrameType.MIN60: [
+ int(s[:2]) * 60 + int(s[2:]) for s in ["1030", "1130", "1400", "1500"]
+ ],
+ }
+ day_frames = None
+ week_frames = None
+ month_frames = None
+ quarter_frames = None
+ year_frames = None
+
+ @classmethod
+ def service_degrade(cls):
+ """当cache中不存在日历时,启用随omicron版本一起发行时自带的日历。
+
+ 注意:随omicron版本一起发行时自带的日历很可能不是最新的,并且可能包含错误。比如,存在这样的情况,在本版本的omicron发行时,日历更新到了2021年12月31日,在这之前的日历都是准确的,但在此之后的日历,则有可能出现错误。因此,只应该在特殊的情况下(比如测试)调用此方法,以获得一个降级的服务。
+ """
+ _dir = os.path.dirname(__file__)
+ file = os.path.join(_dir, "..", "config", "calendar.json")
+ with open(file, "r") as f:
+ data = json.load(f)
+ for k, v in data.items():
+ setattr(cls, k, np.array(v))
+
+ @classmethod
+ async def _load_calendar(cls):
+ """从数据缓存中加载更新日历"""
+ from omicron import cache
+
+ names = [
+ "day_frames",
+ "week_frames",
+ "month_frames",
+ "quarter_frames",
+ "year_frames",
+ ]
+ for name, frame_type in zip(names, cls.day_level_frames):
+ key = f"calendar:{frame_type.value}"
+ result = await cache.security.lrange(key, 0, -1)
+ if result is not None and len(result):
+ frames = [int(x) for x in result]
+ setattr(cls, name, np.array(frames))
+ else: # pragma: no cover
+ raise DataNotReadyError(f"calendar data is not ready: {name} missed")
+
+ @classmethod
+ async def init(cls):
+ """初始化日历"""
+ await cls._load_calendar()
+
+ @classmethod
+ def int2time(cls, tm: int) -> datetime.datetime:
+ """将整数表示的时间转换为`datetime`类型表示
+
+ examples:
+ >>> TimeFrame.int2time(202005011500)
+ datetime.datetime(2020, 5, 1, 15, 0)
+
+ Args:
+ tm: time in YYYYMMDDHHmm format
+
+ Returns:
+ 转换后的时间
+ """
+ s = str(tm)
+ # its 8 times faster than arrow.get()
+ return datetime.datetime(
+ int(s[:4]), int(s[4:6]), int(s[6:8]), int(s[8:10]), int(s[10:12])
+ )
+
+ @classmethod
+ def time2int(cls, tm: Union[datetime.datetime, Arrow]) -> int:
+ """将时间类型转换为整数类型
+
+ tm可以是Arrow类型,也可以是datetime.datetime或者任何其它类型,只要它有year,month...等
+ 属性
+ Examples:
+ >>> TimeFrame.time2int(datetime.datetime(2020, 5, 1, 15))
+ 202005011500
+
+ Args:
+ tm:
+
+ Returns:
+            转换后的整数,比如202005011500
+ """
+ return int(f"{tm.year:04}{tm.month:02}{tm.day:02}{tm.hour:02}{tm.minute:02}")
+
+ @classmethod
+ def date2int(cls, d: Union[datetime.datetime, datetime.date, Arrow]) -> int:
+ """将日期转换为整数表示
+
+ 在zillionare中,如果要对时间和日期进行持久化操作,我们一般将其转换为int类型
+
+ Examples:
+ >>> TimeFrame.date2int(datetime.date(2020,5,1))
+ 20200501
+
+ Args:
+ d: date
+
+ Returns:
+ 日期的整数表示,比如20220211
+ """
+ return int(f"{d.year:04}{d.month:02}{d.day:02}")
+
+ @classmethod
+ def int2date(cls, d: Union[int, str]) -> datetime.date:
+ """将数字表示的日期转换成为日期格式
+
+ Examples:
+ >>> TimeFrame.int2date(20200501)
+ datetime.date(2020, 5, 1)
+
+ Args:
+ d: YYYYMMDD表示的日期
+
+ Returns:
+ 转换后的日期
+ """
+ s = str(d)
+ # it's 8 times faster than arrow.get
+ return datetime.date(int(s[:4]), int(s[4:6]), int(s[6:]))
+
+ @classmethod
+ def day_shift(cls, start: datetime.date, offset: int) -> datetime.date:
+ """对指定日期进行前后移位操作
+
+        如果 offset == 0,则返回start对应的交易日(如果是非交易日,则返回刚结束的一个交易日)
+        如果 offset > 0,则返回start对应的交易日后第 offset 个交易日
+        如果 offset < 0,则返回start对应的交易日前第 offset 个交易日
+
+ Examples:
+ >>> TimeFrame.day_frames = [20191212, 20191213, 20191216, 20191217,20191218, 20191219]
+ >>> TimeFrame.day_shift(datetime.date(2019,12,13), 0)
+ datetime.date(2019, 12, 13)
+
+ >>> TimeFrame.day_shift(datetime.date(2019, 12, 15), 0)
+ datetime.date(2019, 12, 13)
+
+ >>> TimeFrame.day_shift(datetime.date(2019, 12, 15), 1)
+ datetime.date(2019, 12, 16)
+
+ >>> TimeFrame.day_shift(datetime.date(2019, 12, 13), 1)
+ datetime.date(2019, 12, 16)
+
+ Args:
+ start: the origin day
+ offset: days to shift, can be negative
+
+ Returns:
+ 移位后的日期
+ """
+ # accelerated from 0.12 to 0.07, per 10000 loop, type conversion time included
+ start = cls.date2int(start)
+
+ return cls.int2date(ext.shift(cls.day_frames, start, offset))
+
+ @classmethod
+ def week_shift(cls, start: datetime.date, offset: int) -> datetime.date:
+ """对指定日期按周线帧进行前后移位操作
+
+ 参考 [omicron.models.timeframe.TimeFrame.day_shift][]
+ Examples:
+ >>> TimeFrame.week_frames = np.array([20200103, 20200110, 20200117, 20200123,20200207, 20200214])
+ >>> moment = arrow.get('2020-1-21').date()
+ >>> TimeFrame.week_shift(moment, 1)
+ datetime.date(2020, 1, 23)
+
+ >>> TimeFrame.week_shift(moment, 0)
+ datetime.date(2020, 1, 17)
+
+ >>> TimeFrame.week_shift(moment, -1)
+ datetime.date(2020, 1, 10)
+
+ Returns:
+ 移位后的日期
+ """
+ start = cls.date2int(start)
+ return cls.int2date(ext.shift(cls.week_frames, start, offset))
+
+ @classmethod
+ def month_shift(cls, start: datetime.date, offset: int) -> datetime.date:
+ """求`start`所在的月移位后的frame
+
+ 本函数首先将`start`对齐,然后进行移位。
+ Examples:
+ >>> TimeFrame.month_frames = np.array([20150130, 20150227, 20150331, 20150430])
+ >>> TimeFrame.month_shift(arrow.get('2015-2-26').date(), 0)
+ datetime.date(2015, 1, 30)
+
+ >>> TimeFrame.month_shift(arrow.get('2015-2-27').date(), 0)
+ datetime.date(2015, 2, 27)
+
+ >>> TimeFrame.month_shift(arrow.get('2015-3-1').date(), 0)
+ datetime.date(2015, 2, 27)
+
+ >>> TimeFrame.month_shift(arrow.get('2015-3-1').date(), 1)
+ datetime.date(2015, 3, 31)
+
+ Returns:
+ 移位后的日期
+ """
+ start = cls.date2int(start)
+ return cls.int2date(ext.shift(cls.month_frames, start, offset))
+
+ @classmethod
+ def get_ticks(cls, frame_type: FrameType) -> Union[List, np.array]:
+ """取月线、周线、日线及各分钟线对应的frame
+
+ 对分钟线,返回值仅包含时间,不包含日期(均为整数表示)
+
+ Examples:
+ >>> TimeFrame.month_frames = np.array([20050131, 20050228, 20050331])
+ >>> TimeFrame.get_ticks(FrameType.MONTH)[:3]
+ array([20050131, 20050228, 20050331])
+
+ Args:
+ frame_type : [description]
+
+ Raises:
+ ValueError: [description]
+
+ Returns:
+ 月线、周线、日线及各分钟线对应的frame
+ """
+ if frame_type in cls.minute_level_frames:
+ return cls.ticks[frame_type]
+
+ if frame_type == FrameType.DAY:
+ return cls.day_frames
+ elif frame_type == FrameType.WEEK:
+ return cls.week_frames
+ elif frame_type == FrameType.MONTH:
+ return cls.month_frames
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} not supported!")
+
+ @classmethod
+ def shift(
+ cls,
+ moment: Union[Arrow, datetime.date, datetime.datetime],
+ n: int,
+ frame_type: FrameType,
+ ) -> Union[datetime.date, datetime.datetime]:
+ """将指定的moment移动N个`frame_type`位置。
+
+ 当N为负数时,意味着向前移动;当N为正数时,意味着向后移动。如果n为零,意味着移动到最接近
+ 的一个已结束的frame。
+
+ 如果moment没有对齐到frame_type对应的时间,将首先进行对齐。
+
+ See also:
+
+ - [day_shift][omicron.models.timeframe.TimeFrame.day_shift]
+ - [week_shift][omicron.models.timeframe.TimeFrame.week_shift]
+ - [month_shift][omicron.models.timeframe.TimeFrame.month_shift]
+
+ Examples:
+ >>> TimeFrame.shift(datetime.date(2020, 1, 3), 1, FrameType.DAY)
+ datetime.date(2020, 1, 6)
+
+ >>> TimeFrame.shift(datetime.datetime(2020, 1, 6, 11), 1, FrameType.MIN30)
+ datetime.datetime(2020, 1, 6, 11, 30)
+
+
+ Args:
+ moment:
+ n:
+ frame_type:
+
+ Returns:
+ 移位后的Frame
+ """
+ if frame_type == FrameType.DAY:
+ return cls.day_shift(moment, n)
+
+ elif frame_type == FrameType.WEEK:
+ return cls.week_shift(moment, n)
+ elif frame_type == FrameType.MONTH:
+ return cls.month_shift(moment, n)
+ elif frame_type in [
+ FrameType.MIN1,
+ FrameType.MIN5,
+ FrameType.MIN15,
+ FrameType.MIN30,
+ FrameType.MIN60,
+ ]:
+ tm = moment.hour * 60 + moment.minute
+
+ new_tick_pos = cls.ticks[frame_type].index(tm) + n
+ days = new_tick_pos // len(cls.ticks[frame_type])
+ min_part = new_tick_pos % len(cls.ticks[frame_type])
+
+ date_part = cls.day_shift(moment.date(), days)
+ minutes = cls.ticks[frame_type][min_part]
+ h, m = minutes // 60, minutes % 60
+ return datetime.datetime(
+ date_part.year,
+ date_part.month,
+ date_part.day,
+ h,
+ m,
+ tzinfo=moment.tzinfo,
+ )
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} is not supported.")
+
+ @classmethod
+ def count_day_frames(
+ cls, start: Union[datetime.date, Arrow], end: Union[datetime.date, Arrow]
+ ) -> int:
+ """calc trade days between start and end in close-to-close way.
+
+ if start == end, this will returns 1. Both start/end will be aligned to open
+ trade day before calculation.
+
+ Examples:
+ >>> start = datetime.date(2019, 12, 21)
+ >>> end = datetime.date(2019, 12, 21)
+ >>> TimeFrame.day_frames = [20191219, 20191220, 20191223, 20191224, 20191225]
+ >>> TimeFrame.count_day_frames(start, end)
+ 1
+
+ >>> # non-trade days are removed
+ >>> TimeFrame.day_frames = [20200121, 20200122, 20200123, 20200203, 20200204, 20200205]
+ >>> start = datetime.date(2020, 1, 23)
+ >>> end = datetime.date(2020, 2, 4)
+ >>> TimeFrame.count_day_frames(start, end)
+ 3
+
+ args:
+ start:
+ end:
+ returns:
+ count of days
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+ return int(ext.count_between(cls.day_frames, start, end))
+
+ @classmethod
+ def count_week_frames(cls, start: datetime.date, end: datetime.date) -> int:
+ """
+ calc trade weeks between start and end in close-to-close way. Both start and
+ end will be aligned to open trade day before calculation. After that, if start
+ == end, this will returns 1
+
+ for examples, please refer to [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+ args:
+ start:
+ end:
+ returns:
+ count of weeks
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+ return int(ext.count_between(cls.week_frames, start, end))
+
+ @classmethod
+ def count_month_frames(cls, start: datetime.date, end: datetime.date) -> int:
+ """calc trade months between start and end date in close-to-close way
+ Both start and end will be aligned to open trade day before calculation. After
+ that, if start == end, this will returns 1.
+
+ For examples, please refer to [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+
+ Args:
+ start:
+ end:
+
+ Returns:
+ months between start and end
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+
+ return int(ext.count_between(cls.month_frames, start, end))
+
+ @classmethod
+ def count_quarter_frames(cls, start: datetime.date, end: datetime.date) -> int:
+ """calc trade quarters between start and end date in close-to-close way
+ Both start and end will be aligned to open trade day before calculation. After
+ that, if start == end, this will returns 1.
+
+ For examples, please refer to [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+
+ Args:
+ start (datetime.date): [description]
+ end (datetime.date): [description]
+
+ Returns:
+ quarters between start and end
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+
+ return int(ext.count_between(cls.quarter_frames, start, end))
+
+ @classmethod
+ def count_year_frames(cls, start: datetime.date, end: datetime.date) -> int:
+ """calc trade years between start and end date in close-to-close way
+ Both start and end will be aligned to open trade day before calculation. After
+ that, if start == end, this will returns 1.
+
+ For examples, please refer to [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+
+ Args:
+ start (datetime.date): [description]
+ end (datetime.date): [description]
+
+ Returns:
+ years between start and end
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+
+ return int(ext.count_between(cls.year_frames, start, end))
+
+ @classmethod
+ def count_frames(
+ cls,
+ start: Union[datetime.date, datetime.datetime, Arrow],
+ end: Union[datetime.date, datetime.datetime, Arrow],
+ frame_type,
+ ) -> int:
+ """计算start与end之间有多少个周期为frame_type的frames
+
+ See also:
+
+ - [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+ - [count_week_frames][omicron.models.timeframe.TimeFrame.count_week_frames]
+ - [count_month_frames][omicron.models.timeframe.TimeFrame.count_month_frames]
+
+ Args:
+ start : start frame
+ end : end frame
+ frame_type : the type of frame
+
+ Raises:
+ ValueError: 如果frame_type不支持,则会抛出此异常。
+
+ Returns:
+ 从start到end的帧数
+ """
+ if frame_type == FrameType.DAY:
+ return cls.count_day_frames(start, end)
+ elif frame_type == FrameType.WEEK:
+ return cls.count_week_frames(start, end)
+ elif frame_type == FrameType.MONTH:
+ return cls.count_month_frames(start, end)
+ elif frame_type == FrameType.QUARTER:
+ return cls.count_quarter_frames(start, end)
+ elif frame_type == FrameType.YEAR:
+ return cls.count_year_frames(start, end)
+ elif frame_type in [
+ FrameType.MIN1,
+ FrameType.MIN5,
+ FrameType.MIN15,
+ FrameType.MIN30,
+ FrameType.MIN60,
+ ]:
+ tm_start = start.hour * 60 + start.minute
+ tm_end = end.hour * 60 + end.minute
+ days = cls.count_day_frames(start.date(), end.date()) - 1
+
+ tm_start_pos = cls.ticks[frame_type].index(tm_start)
+ tm_end_pos = cls.ticks[frame_type].index(tm_end)
+
+ min_bars = tm_end_pos - tm_start_pos + 1
+
+ return days * len(cls.ticks[frame_type]) + min_bars
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} is not supported yet")
+
+ @classmethod
+ def is_trade_day(cls, dt: Union[datetime.date, datetime.datetime, Arrow]) -> bool:
+ """判断`dt`是否为交易日
+
+ Examples:
+ >>> TimeFrame.is_trade_day(arrow.get('2020-1-1'))
+ False
+
+ Args:
+ dt :
+
+ Returns:
+ bool
+ """
+ return cls.date2int(dt) in cls.day_frames
+
+ @classmethod
+ def is_open_time(cls, tm: Union[datetime.datetime, Arrow] = None) -> bool:
+ """判断`tm`指定的时间是否处在交易时间段。
+
+ 交易时间段是指集合竞价时间段之外的开盘时间
+
+ Examples:
+ >>> TimeFrame.day_frames = np.array([20200102, 20200103, 20200106, 20200107, 20200108])
+ >>> TimeFrame.is_open_time(arrow.get('2020-1-1 14:59').naive)
+ False
+ >>> TimeFrame.is_open_time(arrow.get('2020-1-3 14:59').naive)
+ True
+
+ Args:
+ tm : [description]. Defaults to None.
+
+ Returns:
+ bool
+ """
+ tm = tm or arrow.now()
+
+ if not cls.is_trade_day(tm):
+ return False
+
+ tick = tm.hour * 60 + tm.minute
+ return tick in cls.ticks[FrameType.MIN1]
+
+ @classmethod
+ def is_opening_call_auction_time(
+ cls, tm: Union[Arrow, datetime.datetime] = None
+ ) -> bool:
+ """判断`tm`指定的时间是否为开盘集合竞价时间
+
+ Args:
+ tm : [description]. Defaults to None.
+
+ Returns:
+ bool
+ """
+ if tm is None:
+ tm = cls.now()
+
+ if not cls.is_trade_day(tm):
+ return False
+
+ minutes = tm.hour * 60 + tm.minute
+ return 9 * 60 + 15 < minutes <= 9 * 60 + 25
+
+ @classmethod
+ def is_closing_call_auction_time(
+ cls, tm: Union[datetime.datetime, Arrow] = None
+ ) -> bool:
+ """判断`tm`指定的时间是否为收盘集合竞价时间
+
+ Fixme:
+ 此处实现有误,收盘集合竞价时间应该还包含上午收盘时间
+
+ Args:
+ tm : [description]. Defaults to None.
+
+ Returns:
+ bool
+ """
+ tm = tm or cls.now()
+
+ if not cls.is_trade_day(tm):
+ return False
+
+ minutes = tm.hour * 60 + tm.minute
+ return 15 * 60 - 3 <= minutes < 15 * 60
+
+ @classmethod
+ def floor(cls, moment: Frame, frame_type: FrameType) -> Frame:
+ """求`moment`在指定的`frame_type`中的下界
+
+        比如,如果`moment`为10:37,则当`frame_type`为30分钟时,对应的下界为10:00
+
+ Examples:
+ >>> # 如果moment为日期,则当成已收盘处理
+ >>> TimeFrame.day_frames = np.array([20050104, 20050105, 20050106, 20050107, 20050110, 20050111])
+ >>> TimeFrame.floor(datetime.date(2005, 1, 7), FrameType.DAY)
+ datetime.date(2005, 1, 7)
+
+ >>> # moment指定的时间还未收盘,floor到上一个交易日
+ >>> TimeFrame.floor(datetime.datetime(2005, 1, 7, 14, 59), FrameType.DAY)
+ datetime.date(2005, 1, 6)
+
+ >>> TimeFrame.floor(datetime.date(2005, 1, 13), FrameType.WEEK)
+ datetime.date(2005, 1, 7)
+
+ >>> TimeFrame.floor(datetime.date(2005,2, 27), FrameType.MONTH)
+ datetime.date(2005, 1, 31)
+
+ >>> TimeFrame.floor(datetime.datetime(2005,1,5,14,59), FrameType.MIN30)
+ datetime.datetime(2005, 1, 5, 14, 30)
+
+ >>> TimeFrame.floor(datetime.datetime(2005, 1, 5, 14, 59), FrameType.MIN1)
+ datetime.datetime(2005, 1, 5, 14, 59)
+
+ >>> TimeFrame.floor(arrow.get('2005-1-5 14:59').naive, FrameType.MIN1)
+ datetime.datetime(2005, 1, 5, 14, 59)
+
+ Args:
+ moment:
+ frame_type:
+
+ Returns:
+ `moment`在指定的`frame_type`中的下界
+ """
+ if frame_type in cls.minute_level_frames:
+ tm, day_offset = cls.minute_frames_floor(
+ cls.ticks[frame_type], moment.hour * 60 + moment.minute
+ )
+ h, m = tm // 60, tm % 60
+ if cls.day_shift(moment, 0) < moment.date() or day_offset == -1:
+ h = 15
+ m = 0
+ new_day = cls.day_shift(moment, day_offset)
+ else:
+ new_day = moment.date()
+ return datetime.datetime(new_day.year, new_day.month, new_day.day, h, m)
+
+ if type(moment) == datetime.date:
+ moment = datetime.datetime(moment.year, moment.month, moment.day, 15)
+
+ # 如果是交易日,但还未收盘
+ if (
+ cls.date2int(moment) in cls.day_frames
+ and moment.hour * 60 + moment.minute < 900
+ ):
+ moment = cls.day_shift(moment, -1)
+
+ day = cls.date2int(moment)
+ if frame_type == FrameType.DAY:
+ arr = cls.day_frames
+ elif frame_type == FrameType.WEEK:
+ arr = cls.week_frames
+ elif frame_type == FrameType.MONTH:
+ arr = cls.month_frames
+ else: # pragma: no cover
+ raise ValueError(f"frame type {frame_type} not supported.")
+
+ floored = ext.floor(arr, day)
+ return cls.int2date(floored)
+
+ @classmethod
+ def last_min_frame(
+ cls, day: Union[str, Arrow, datetime.date], frame_type: FrameType
+ ) -> Union[datetime.date, datetime.datetime]:
+ """获取`day`日周期为`frame_type`的结束frame。
+
+ Example:
+ >>> TimeFrame.last_min_frame(arrow.get('2020-1-5').date(), FrameType.MIN30)
+ datetime.datetime(2020, 1, 3, 15, 0)
+
+ Args:
+ day:
+ frame_type:
+
+ Returns:
+ `day`日周期为`frame_type`的结束frame
+ """
+ if isinstance(day, str):
+ day = cls.date2int(arrow.get(day).date())
+ elif isinstance(day, arrow.Arrow) or isinstance(day, datetime.datetime):
+ day = cls.date2int(day.date())
+ elif isinstance(day, datetime.date):
+ day = cls.date2int(day)
+ else:
+ raise TypeError(f"{type(day)} is not supported.")
+
+ if frame_type in cls.minute_level_frames:
+ last_close_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(last_close_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=15, minute=0)
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} not supported")
+
+ @classmethod
+ def frame_len(cls, frame_type: FrameType) -> int:
+ """返回以分钟为单位的frame长度。
+
+ 对日线以上级别没有意义,但会返回240
+
+ Examples:
+ >>> TimeFrame.frame_len(FrameType.MIN5)
+ 5
+
+ Args:
+ frame_type:
+
+ Returns:
+ 返回以分钟为单位的frame长度。
+
+ """
+
+ if frame_type == FrameType.MIN1:
+ return 1
+ elif frame_type == FrameType.MIN5:
+ return 5
+ elif frame_type == FrameType.MIN15:
+ return 15
+ elif frame_type == FrameType.MIN30:
+ return 30
+ elif frame_type == FrameType.MIN60:
+ return 60
+ else:
+ return 240
+
+ @classmethod
+ def first_min_frame(
+ cls, day: Union[str, Arrow, Frame], frame_type: FrameType
+ ) -> Union[datetime.date, datetime.datetime]:
+ """获取指定日期类型为`frame_type`的`frame`。
+
+ Examples:
+ >>> TimeFrame.day_frames = np.array([20191227, 20191230, 20191231, 20200102, 20200103])
+ >>> TimeFrame.first_min_frame('2019-12-31', FrameType.MIN1)
+ datetime.datetime(2019, 12, 31, 9, 31)
+
+ Args:
+ day: which day?
+ frame_type: which frame_type?
+
+ Returns:
+ `day`当日的第一帧
+ """
+ day = cls.date2int(arrow.get(day).date())
+
+ if frame_type == FrameType.MIN1:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=9, minute=31)
+ elif frame_type == FrameType.MIN5:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=9, minute=35)
+ elif frame_type == FrameType.MIN15:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=9, minute=45)
+ elif frame_type == FrameType.MIN30:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=10)
+ elif frame_type == FrameType.MIN60:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=10, minute=30)
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} not supported")
+
+ @classmethod
+ def get_frames(cls, start: Frame, end: Frame, frame_type: FrameType) -> List[int]:
+ """取[start, end]间所有类型为frame_type的frames
+
+ 调用本函数前,请先通过`floor`或者`ceiling`将时间帧对齐到`frame_type`的边界值
+
+ Example:
+ >>> start = arrow.get('2020-1-13 10:00').naive
+ >>> end = arrow.get('2020-1-13 13:30').naive
+ >>> TimeFrame.day_frames = np.array([20200109, 20200110, 20200113,20200114, 20200115, 20200116])
+ >>> TimeFrame.get_frames(start, end, FrameType.MIN30)
+ [202001131000, 202001131030, 202001131100, 202001131130, 202001131330]
+
+ Args:
+ start:
+ end:
+ frame_type:
+
+ Returns:
+ frame list
+ """
+ n = cls.count_frames(start, end, frame_type)
+ return cls.get_frames_by_count(end, n, frame_type)
+
+ @classmethod
+ def get_frames_by_count(
+ cls, end: Arrow, n: int, frame_type: FrameType
+ ) -> List[int]:
+ """取以end为结束点,周期为frame_type的n个frame
+
+ 调用前请将`end`对齐到`frame_type`的边界
+
+ Examples:
+ >>> end = arrow.get('2020-1-6 14:30').naive
+ >>> TimeFrame.day_frames = np.array([20200102, 20200103,20200106, 20200107, 20200108, 20200109])
+ >>> TimeFrame.get_frames_by_count(end, 2, FrameType.MIN30)
+ [202001061400, 202001061430]
+
+ Args:
+ end:
+ n:
+ frame_type:
+
+ Returns:
+ frame list
+ """
+
+ if frame_type == FrameType.DAY:
+ end = cls.date2int(end)
+ pos = np.searchsorted(cls.day_frames, end, side="right")
+ return cls.day_frames[max(0, pos - n) : pos].tolist()
+ elif frame_type == FrameType.WEEK:
+ end = cls.date2int(end)
+ pos = np.searchsorted(cls.week_frames, end, side="right")
+ return cls.week_frames[max(0, pos - n) : pos].tolist()
+ elif frame_type == FrameType.MONTH:
+ end = cls.date2int(end)
+ pos = np.searchsorted(cls.month_frames, end, side="right")
+ return cls.month_frames[max(0, pos - n) : pos].tolist()
+ elif frame_type in {
+ FrameType.MIN1,
+ FrameType.MIN5,
+ FrameType.MIN15,
+ FrameType.MIN30,
+ FrameType.MIN60,
+ }:
+ n_days = n // len(cls.ticks[frame_type]) + 2
+ ticks = cls.ticks[frame_type] * n_days
+
+ days = cls.get_frames_by_count(end, n_days, FrameType.DAY)
+ days = np.repeat(days, len(cls.ticks[frame_type]))
+
+ ticks = [
+ day.item() * 10000 + int(tm / 60) * 100 + tm % 60
+ for day, tm in zip(days, ticks)
+ ]
+
+ # list index is much faster than ext.index_sorted when the arr is small
+ pos = ticks.index(cls.time2int(end)) + 1
+
+ return ticks[max(0, pos - n) : pos]
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} not support yet")
+
+ @classmethod
+ def ceiling(cls, moment: Frame, frame_type: FrameType) -> Frame:
+ """求`moment`所在类型为`frame_type`周期的上界
+
+ 比如`moment`为14:59分,如果`frame_type`为30分钟,则它的上界应该为15:00
+
+ Example:
+ >>> TimeFrame.day_frames = [20050104, 20050105, 20050106, 20050107]
+ >>> TimeFrame.ceiling(datetime.date(2005, 1, 7), FrameType.DAY)
+ datetime.date(2005, 1, 7)
+
+ >>> TimeFrame.week_frames = [20050107, 20050114, 20050121, 20050128]
+ >>> TimeFrame.ceiling(datetime.date(2005, 1, 4), FrameType.WEEK)
+ datetime.date(2005, 1, 7)
+
+ >>> TimeFrame.ceiling(datetime.date(2005,1,7), FrameType.WEEK)
+ datetime.date(2005, 1, 7)
+
+ >>> TimeFrame.month_frames = [20050131, 20050228]
+ >>> TimeFrame.ceiling(datetime.date(2005,1 ,1), FrameType.MONTH)
+ datetime.date(2005, 1, 31)
+
+ >>> TimeFrame.ceiling(datetime.datetime(2005,1,5,14,59), FrameType.MIN30)
+ datetime.datetime(2005, 1, 5, 15, 0)
+
+ >>> TimeFrame.ceiling(datetime.datetime(2005, 1, 5, 14, 59), FrameType.MIN1)
+ datetime.datetime(2005, 1, 5, 14, 59)
+
+ >>> TimeFrame.ceiling(arrow.get('2005-1-5 14:59').naive, FrameType.MIN1)
+ datetime.datetime(2005, 1, 5, 14, 59)
+
+ Args:
+ moment (datetime.datetime): [description]
+ frame_type (FrameType): [description]
+
+ Returns:
+ `moment`所在类型为`frame_type`周期的上界
+ """
+ if frame_type in cls.day_level_frames and type(moment) == datetime.datetime:
+ moment = moment.date()
+
+ floor = cls.floor(moment, frame_type)
+ if floor == moment:
+ return moment
+ elif floor > moment:
+ return floor
+ else:
+ return cls.shift(floor, 1, frame_type)
+
+ @classmethod
+ def combine_time(
+ cls,
+ date: datetime.date,
+ hour: int,
+ minute: int = 0,
+ second: int = 0,
+ microsecond: int = 0,
+ ) -> datetime.datetime:
+ """用`date`指定的日期与`hour`, `minute`, `second`等参数一起合成新的时间
+
+ Examples:
+ >>> TimeFrame.combine_time(datetime.date(2020, 1, 1), 14, 30)
+ datetime.datetime(2020, 1, 1, 14, 30)
+
+ Args:
+ date : [description]
+ hour : [description]
+ minute : [description]. Defaults to 0.
+ second : [description]. Defaults to 0.
+ microsecond : [description]. Defaults to 0.
+
+ Returns:
+ 合成后的时间
+ """
+ return datetime.datetime(
+ date.year, date.month, date.day, hour, minute, second, microsecond
+ )
+
+ @classmethod
+ def replace_date(
+ cls, dtm: datetime.datetime, dt: datetime.date
+ ) -> datetime.datetime:
+ """将`dtm`变量的日期更换为`dt`指定的日期
+
+ Example:
+ >>> TimeFrame.replace_date(arrow.get('2020-1-1 13:49').datetime, datetime.date(2019, 1,1))
+ datetime.datetime(2019, 1, 1, 13, 49)
+
+ Args:
+ dtm (datetime.datetime): [description]
+ dt (datetime.date): [description]
+
+ Returns:
+ 变换后的时间
+ """
+ return datetime.datetime(
+ dt.year, dt.month, dt.day, dtm.hour, dtm.minute, dtm.second, dtm.microsecond
+ )
+
+ @classmethod
+ def resample_frames(
+ cls, trade_days: Iterable[datetime.date], frame_type: FrameType
+ ) -> List[int]:
+ """将从行情服务器获取的交易日历重采样,生成周帧和月线帧
+
+ Args:
+ trade_days (Iterable): [description]
+ frame_type (FrameType): [description]
+
+ Returns:
+ List[int]: 重采样后的日期列表,日期用整数表示
+ """
+ if frame_type == FrameType.WEEK:
+ weeks = []
+ last = trade_days[0]
+ for cur in trade_days:
+ if cur.weekday() < last.weekday() or (cur - last).days >= 7:
+ weeks.append(last)
+ last = cur
+
+ if weeks[-1] < last:
+ weeks.append(last)
+
+ return weeks
+ elif frame_type == FrameType.MONTH:
+ months = []
+ last = trade_days[0]
+ for cur in trade_days:
+ if cur.day < last.day:
+ months.append(last)
+ last = cur
+ months.append(last)
+
+ return months
+ elif frame_type == FrameType.QUARTER:
+ quarters = []
+ last = trade_days[0]
+ for cur in trade_days:
+ if last.month % 3 == 0:
+ if cur.month > last.month or cur.year > last.year:
+ quarters.append(last)
+ last = cur
+ quarters.append(last)
+
+ return quarters
+ elif frame_type == FrameType.YEAR:
+ years = []
+ last = trade_days[0]
+ for cur in trade_days:
+ if cur.year > last.year:
+ years.append(last)
+ last = cur
+ years.append(last)
+
+ return years
+ else: # pragma: no cover
+ raise ValueError(f"Unsupported FrameType: {frame_type}")
+
+ @classmethod
+ def minute_frames_floor(cls, ticks, moment) -> Tuple[int, int]:
+ """
+ 对于分钟级的frame,返回它们与frame刻度向下对齐后的frame及日期进位。如果需要对齐到上一个交易
+ 日,则进位为-1,否则为0.
+
+ Examples:
+ >>> ticks = [600, 630, 660, 690, 810, 840, 870, 900]
+ >>> TimeFrame.minute_frames_floor(ticks, 545)
+ (900, -1)
+ >>> TimeFrame.minute_frames_floor(ticks, 600)
+ (600, 0)
+ >>> TimeFrame.minute_frames_floor(ticks, 605)
+ (600, 0)
+ >>> TimeFrame.minute_frames_floor(ticks, 899)
+ (870, 0)
+ >>> TimeFrame.minute_frames_floor(ticks, 900)
+ (900, 0)
+ >>> TimeFrame.minute_frames_floor(ticks, 905)
+ (900, 0)
+
+ Args:
+ ticks (np.array or list): frames刻度
+ moment (int): 整数表示的分钟数,比如900表示15:00
+
+ Returns:
+ tuple, the first is the new moment, the second is carry-on
+ """
+ if moment < ticks[0]:
+ return ticks[-1], -1
+        # 'right' 相当于 ticks <= moment
+ index = np.searchsorted(ticks, moment, side="right")
+ return ticks[index - 1], 0
+
+ @classmethod
+ async def save_calendar(cls, trade_days):
+ # avoid circular import
+ from omicron import cache
+
+ for ft in [FrameType.WEEK, FrameType.MONTH, FrameType.QUARTER, FrameType.YEAR]:
+ days = cls.resample_frames(trade_days, ft)
+ frames = [cls.date2int(x) for x in days]
+
+ key = f"calendar:{ft.value}"
+ pl = cache.security.pipeline()
+ pl.delete(key)
+ pl.rpush(key, *frames)
+ await pl.execute()
+
+ frames = [cls.date2int(x) for x in trade_days]
+ key = f"calendar:{FrameType.DAY.value}"
+ pl = cache.security.pipeline()
+ pl.delete(key)
+ pl.rpush(key, *frames)
+ await pl.execute()
+
+ @classmethod
+ async def remove_calendar(cls):
+ # avoid circular import
+ from omicron import cache
+
+ for ft in cls.day_level_frames:
+ key = f"calendar:{ft.value}"
+ await cache.security.delete(key)
+
+ @classmethod
+ def is_bar_closed(cls, frame: Frame, ft: FrameType) -> bool:
+ """判断`frame`所代表的bar是否已经收盘(结束)
+
+ 如果是日线,frame不为当天,则认为已收盘;或者当前时间在收盘时间之后,也认为已收盘。
+ 如果是其它周期,则只有当frame正好在边界上,才认为是已收盘。这里有一个假设:我们不会在其它周期上,判断未来的某个frame是否已经收盘。
+
+ Args:
+ frame : bar所处的时间,必须小于当前时间
+ ft: bar所代表的帧类型
+
+ Returns:
+ bool: 是否已经收盘
+ """
+ floor = cls.floor(frame, ft)
+
+ now = arrow.now()
+ if ft == FrameType.DAY:
+ return floor < now.date() or now.hour >= 15
+ else:
+ return floor == frame
+
+ @classmethod
+ def get_frame_scope(cls, frame: Frame, ft: FrameType) -> Tuple[Frame, Frame]:
+ # todo: 函数的通用性不足,似乎应该放在具体的业务类中。如果是通用型的函数,参数不应该局限于周和月。
+ """对于给定的时间,取所在周的第一天和最后一天,所在月的第一天和最后一天
+
+ Args:
+ frame : 指定的日期,date对象
+ ft: 帧类型,支持WEEK和MONTH
+
+ Returns:
+ Tuple[Frame, Frame]: 周或者月的首末日期(date对象)
+
+ """
+ if frame is None:
+ raise ValueError("frame cannot be None")
+ if ft not in (FrameType.WEEK, FrameType.MONTH):
+ raise ValueError(f"FrameType only supports WEEK and MONTH: {ft}")
+
+ if isinstance(frame, datetime.datetime):
+ frame = frame.date()
+
+ if frame < CALENDAR_START:
+ raise ValueError(f"cannot be earlier than {CALENDAR_START}: {frame}")
+
+ # datetime.date(2021, 10, 8),这是个特殊的日期
+ if ft == FrameType.WEEK:
+ if frame < datetime.date(2005, 1, 10):
+ return datetime.date(2005, 1, 4), datetime.date(2005, 1, 7)
+
+ if not cls.is_trade_day(frame): # 非交易日的情况,直接回退一天
+ week_day = cls.day_shift(frame, 0)
+ else:
+ week_day = frame
+
+ w1 = TimeFrame.floor(week_day, FrameType.WEEK)
+ if w1 == week_day: # 本周的最后一个交易日
+ week_end = w1
+ else:
+ week_end = TimeFrame.week_shift(week_day, 1)
+
+ w0 = TimeFrame.week_shift(week_end, -1)
+ week_start = TimeFrame.day_shift(w0, 1)
+ return week_start, week_end
+
+ if ft == FrameType.MONTH:
+ if frame <= datetime.date(2005, 1, 31):
+ return datetime.date(2005, 1, 4), datetime.date(2005, 1, 31)
+
+ month_start = frame.replace(day=1)
+ if not cls.is_trade_day(month_start): # 非交易日的情况,直接加1
+ month_start = cls.day_shift(month_start, 1)
+
+ month_end = TimeFrame.month_shift(month_start, 1)
+ return month_start, month_end
+
+ @classmethod
+ def get_previous_trade_day(cls, now: datetime.date):
+ """获取上一个交易日
+
+ 如果当天是周六或者周日,返回周五(交易日),如果当天是周一,返回周五,如果当天是周五,返回周四
+
+ Args:
+ now : 指定的日期,date对象
+
+ Returns:
+ datetime.date: 上一个交易日
+
+ """
+ if now == datetime.date(2005, 1, 4):
+ return now
+
+ if TimeFrame.is_trade_day(now):
+ pre_trade_day = TimeFrame.day_shift(now, -1)
+ else:
+ pre_trade_day = TimeFrame.day_shift(now, 0)
+ return pre_trade_day
+
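Taken together, a typical offline usage pattern looks like the sketch below. It assumes no cache is available, so it falls back to the bundled calendar via service_degrade (which, as the docstring warns, may be stale; in production one would await TimeFrame.init() instead). The coretypes import path for FrameType is an assumption, and the chosen dates must lie inside the bundled calendar's range:

```python
import datetime

from coretypes import FrameType  # assumed import path for FrameType
from omicron.models.timeframe import TimeFrame

# offline/testing fallback: load the calendar bundled with omicron
TimeFrame.service_degrade()

day = datetime.date(2022, 1, 10)  # must lie inside the bundled calendar
print(TimeFrame.is_trade_day(day))
print(TimeFrame.day_shift(day, -1))  # the previous trade day
print(TimeFrame.floor(datetime.datetime(2022, 1, 10, 10, 37), FrameType.MIN30))
print(TimeFrame.count_frames(datetime.date(2022, 1, 4), day, FrameType.DAY))
```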
ceiling(moment, frame_type)
+
+
+ classmethod
+
+
¶Find the upper bound of the period of type frame_type that moment falls in.

For example, if moment is 14:59 and frame_type is 30 minutes, the upper bound is 15:00.
Examples:
+>>> TimeFrame.day_frames = [20050104, 20050105, 20050106, 20050107]
+>>> TimeFrame.ceiling(datetime.date(2005, 1, 7), FrameType.DAY)
+datetime.date(2005, 1, 7)
+
>>> TimeFrame.week_frames = [20050107, 20050114, 20050121, 20050128]
+>>> TimeFrame.ceiling(datetime.date(2005, 1, 4), FrameType.WEEK)
+datetime.date(2005, 1, 7)
+
>>> TimeFrame.ceiling(datetime.date(2005,1,7), FrameType.WEEK)
+datetime.date(2005, 1, 7)
+
>>> TimeFrame.month_frames = [20050131, 20050228]
+>>> TimeFrame.ceiling(datetime.date(2005,1 ,1), FrameType.MONTH)
+datetime.date(2005, 1, 31)
+
>>> TimeFrame.ceiling(datetime.datetime(2005,1,5,14,59), FrameType.MIN30)
+datetime.datetime(2005, 1, 5, 15, 0)
+
>>> TimeFrame.ceiling(datetime.datetime(2005, 1, 5, 14, 59), FrameType.MIN1)
+datetime.datetime(2005, 1, 5, 14, 59)
+
>>> TimeFrame.ceiling(arrow.get('2005-1-5 14:59').naive, FrameType.MIN1)
+datetime.datetime(2005, 1, 5, 14, 59)
+
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| moment | datetime.datetime | [description] | required |
| frame_type | FrameType | [description] | required |

Returns:

| Type | Description |
|---|---|
| Frame | the upper bound of the period of type frame_type that moment falls in |
+
omicron/models/timeframe.py
@classmethod
+def ceiling(cls, moment: Frame, frame_type: FrameType) -> Frame:
+ """求`moment`所在类型为`frame_type`周期的上界
+
+ 比如`moment`为14:59分,如果`frame_type`为30分钟,则它的上界应该为15:00
+
+ Example:
+ >>> TimeFrame.day_frames = [20050104, 20050105, 20050106, 20050107]
+ >>> TimeFrame.ceiling(datetime.date(2005, 1, 7), FrameType.DAY)
+ datetime.date(2005, 1, 7)
+
+ >>> TimeFrame.week_frames = [20050107, 20050114, 20050121, 20050128]
+ >>> TimeFrame.ceiling(datetime.date(2005, 1, 4), FrameType.WEEK)
+ datetime.date(2005, 1, 7)
+
+ >>> TimeFrame.ceiling(datetime.date(2005,1,7), FrameType.WEEK)
+ datetime.date(2005, 1, 7)
+
+ >>> TimeFrame.month_frames = [20050131, 20050228]
+ >>> TimeFrame.ceiling(datetime.date(2005,1 ,1), FrameType.MONTH)
+ datetime.date(2005, 1, 31)
+
+ >>> TimeFrame.ceiling(datetime.datetime(2005,1,5,14,59), FrameType.MIN30)
+ datetime.datetime(2005, 1, 5, 15, 0)
+
+ >>> TimeFrame.ceiling(datetime.datetime(2005, 1, 5, 14, 59), FrameType.MIN1)
+ datetime.datetime(2005, 1, 5, 14, 59)
+
+ >>> TimeFrame.ceiling(arrow.get('2005-1-5 14:59').naive, FrameType.MIN1)
+ datetime.datetime(2005, 1, 5, 14, 59)
+
+ Args:
+ moment (datetime.datetime): [description]
+ frame_type (FrameType): [description]
+
+ Returns:
+ `moment`所在类型为`frame_type`周期的上界
+ """
+ if frame_type in cls.day_level_frames and type(moment) == datetime.datetime:
+ moment = moment.date()
+
+ floor = cls.floor(moment, frame_type)
+ if floor == moment:
+ return moment
+ elif floor > moment:
+ return floor
+ else:
+ return cls.shift(floor, 1, frame_type)
+
combine_time(date, hour, minute=0, second=0, microsecond=0)
+
+
+ classmethod
+
+
¶Combine the date given by date with the hour, minute, second, etc. arguments into a new datetime.
Examples:
+>>> TimeFrame.combine_time(datetime.date(2020, 1, 1), 14, 30)
+datetime.datetime(2020, 1, 1, 14, 30)
+
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| date | datetime.date | [description] | required |
| hour | int | [description] | required |
| minute | int | [description]. Defaults to 0. | 0 |
| second | int | [description]. Defaults to 0. | 0 |
| microsecond | int | [description]. Defaults to 0. | 0 |

Returns:

| Type | Description |
|---|---|
| datetime.datetime | the combined datetime |
+
omicron/models/timeframe.py
@classmethod
+def combine_time(
+ cls,
+ date: datetime.date,
+ hour: int,
+ minute: int = 0,
+ second: int = 0,
+ microsecond: int = 0,
+) -> datetime.datetime:
+ """用`date`指定的日期与`hour`, `minute`, `second`等参数一起合成新的时间
+
+ Examples:
+ >>> TimeFrame.combine_time(datetime.date(2020, 1, 1), 14, 30)
+ datetime.datetime(2020, 1, 1, 14, 30)
+
+ Args:
+ date : [description]
+ hour : [description]
+ minute : [description]. Defaults to 0.
+ second : [description]. Defaults to 0.
+ microsecond : [description]. Defaults to 0.
+
+ Returns:
+ 合成后的时间
+ """
+ return datetime.datetime(
+ date.year, date.month, date.day, hour, minute, second, microsecond
+ )
+
count_day_frames(start, end)
+
+
+ classmethod
+
+
¶Calculate the number of trade days between start and end, counted close-to-close.

If start == end, this returns 1. Both start and end are aligned to an open trade day before calculation.

Examples:
+>>> start = datetime.date(2019, 12, 21)
+>>> end = datetime.date(2019, 12, 21)
+>>> TimeFrame.day_frames = [20191219, 20191220, 20191223, 20191224, 20191225]
+>>> TimeFrame.count_day_frames(start, end)
+1
+
>>> # non-trade days are removed
+>>> TimeFrame.day_frames = [20200121, 20200122, 20200123, 20200203, 20200204, 20200205]
+>>> start = datetime.date(2020, 1, 23)
+>>> end = datetime.date(2020, 2, 4)
+>>> TimeFrame.count_day_frames(start, end)
+3
+
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| start | Union[datetime.date, Arrow] | | required |
| end | Union[datetime.date, Arrow] | | required |

Returns:

| Type | Description |
|---|---|
| int | count of days |
+
omicron/models/timeframe.py
@classmethod
+def count_day_frames(
+ cls, start: Union[datetime.date, Arrow], end: Union[datetime.date, Arrow]
+) -> int:
+ """calc trade days between start and end in close-to-close way.
+
+ if start == end, this will returns 1. Both start/end will be aligned to open
+ trade day before calculation.
+
+ Examples:
+ >>> start = datetime.date(2019, 12, 21)
+ >>> end = datetime.date(2019, 12, 21)
+ >>> TimeFrame.day_frames = [20191219, 20191220, 20191223, 20191224, 20191225]
+ >>> TimeFrame.count_day_frames(start, end)
+ 1
+
+ >>> # non-trade days are removed
+ >>> TimeFrame.day_frames = [20200121, 20200122, 20200123, 20200203, 20200204, 20200205]
+ >>> start = datetime.date(2020, 1, 23)
+ >>> end = datetime.date(2020, 2, 4)
+ >>> TimeFrame.count_day_frames(start, end)
+ 3
+
+ args:
+ start:
+ end:
+ returns:
+ count of days
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+ return int(ext.count_between(cls.day_frames, start, end))
+
count_frames(start, end, frame_type)
+
+
+ classmethod
+
+
¶Count how many frames of period frame_type lie between start and end.

See also:

- [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
- [count_week_frames][omicron.models.timeframe.TimeFrame.count_week_frames]
- [count_month_frames][omicron.models.timeframe.TimeFrame.count_month_frames]

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| start | | start frame | required |
| end | | end frame | required |
| frame_type | | the type of frame | required |

Exceptions:

| Type | Description |
|---|---|
| ValueError | raised if frame_type is not supported. |

Returns:

| Type | Description |
|---|---|
| int | the number of frames from start to end |
+
omicron/models/timeframe.py
@classmethod
+def count_frames(
+ cls,
+ start: Union[datetime.date, datetime.datetime, Arrow],
+ end: Union[datetime.date, datetime.datetime, Arrow],
+ frame_type,
+) -> int:
+ """计算start与end之间有多少个周期为frame_type的frames
+
+ See also:
+
+ - [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+ - [count_week_frames][omicron.models.timeframe.TimeFrame.count_week_frames]
+ - [count_month_frames][omicron.models.timeframe.TimeFrame.count_month_frames]
+
+ Args:
+ start : start frame
+ end : end frame
+ frame_type : the type of frame
+
+ Raises:
+ ValueError: 如果frame_type不支持,则会抛出此异常。
+
+ Returns:
+ 从start到end的帧数
+ """
+ if frame_type == FrameType.DAY:
+ return cls.count_day_frames(start, end)
+ elif frame_type == FrameType.WEEK:
+ return cls.count_week_frames(start, end)
+ elif frame_type == FrameType.MONTH:
+ return cls.count_month_frames(start, end)
+ elif frame_type == FrameType.QUARTER:
+ return cls.count_quarter_frames(start, end)
+ elif frame_type == FrameType.YEAR:
+ return cls.count_year_frames(start, end)
+ elif frame_type in [
+ FrameType.MIN1,
+ FrameType.MIN5,
+ FrameType.MIN15,
+ FrameType.MIN30,
+ FrameType.MIN60,
+ ]:
+ tm_start = start.hour * 60 + start.minute
+ tm_end = end.hour * 60 + end.minute
+ days = cls.count_day_frames(start.date(), end.date()) - 1
+
+ tm_start_pos = cls.ticks[frame_type].index(tm_start)
+ tm_end_pos = cls.ticks[frame_type].index(tm_end)
+
+ min_bars = tm_end_pos - tm_start_pos + 1
+
+ return days * len(cls.ticks[frame_type]) + min_bars
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} is not supported yet")
+
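+For minute-level frame types, the count combines whole trading days with the positions of start and end on the intraday tick grid. A quick sketch, assuming the standard 30-minute ticks and the two-day calendar set below:
+>>> TimeFrame.day_frames = np.array([20200106, 20200107])
+>>> start = datetime.datetime(2020, 1, 6, 10, 0)
+>>> end = datetime.datetime(2020, 1, 7, 10, 0)
+>>> TimeFrame.count_frames(start, end, FrameType.MIN30)
+9
+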
count_month_frames(start, end)
+
+
+ classmethod
+
+
+¶Calculate trade months between start and end, close-to-close. Both start and end are aligned to an open trade day before calculation. After that, if start == end, this returns 1.
+For examples, please refer to count_day_frames
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+start | datetime.date | | required
+end | datetime.date | | required
+
Returns:
+
+Type | Description
+---|---
+int | months between start and end
+
omicron/models/timeframe.py
@classmethod
+def count_month_frames(cls, start: datetime.date, end: datetime.date) -> int:
+ """calc trade months between start and end date in close-to-close way
+ Both start and end will be aligned to open trade day before calculation. After
+ that, if start == end, this will returns 1.
+
+ For examples, please refer to [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+
+ Args:
+ start:
+ end:
+
+ Returns:
+ months between start and end
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+
+ return int(ext.count_between(cls.month_frames, start, end))
+
count_quarter_frames(start, end)
+
+
+ classmethod
+
+
+¶Calculate trade quarters between start and end, close-to-close. Both start and end are aligned to an open trade day before calculation. After that, if start == end, this returns 1.
+For examples, please refer to count_day_frames
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+start | datetime.date | [description] | required
+end | datetime.date | [description] | required
+
Returns:
+
+Type | Description
+---|---
+int | quarters between start and end
+
omicron/models/timeframe.py
@classmethod
+def count_quarter_frames(cls, start: datetime.date, end: datetime.date) -> int:
+ """calc trade quarters between start and end date in close-to-close way
+ Both start and end will be aligned to open trade day before calculation. After
+ that, if start == end, this will returns 1.
+
+ For examples, please refer to [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+
+ Args:
+ start (datetime.date): [description]
+ end (datetime.date): [description]
+
+ Returns:
+ quarters between start and end
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+
+ return int(ext.count_between(cls.quarter_frames, start, end))
+
count_week_frames(start, end)
+
+
+ classmethod
+
+
+¶Calculate trade weeks between start and end, close-to-close. Both start and end are aligned to an open trade day before calculation. After that, if start == end, this returns 1.
+For examples, please refer to count_day_frames
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+start | datetime.date | | required
+end | datetime.date | | required
+
Returns:
+
+Type | Description
+---|---
+int | count of weeks
+
omicron/models/timeframe.py
@classmethod
+def count_week_frames(cls, start: datetime.date, end: datetime.date) -> int:
+ """
+ calc trade weeks between start and end in close-to-close way. Both start and
+ end will be aligned to open trade day before calculation. After that, if start
+ == end, this will returns 1
+
+ for examples, please refer to [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+ args:
+ start:
+ end:
+ returns:
+ count of weeks
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+ return int(ext.count_between(cls.week_frames, start, end))
+
count_year_frames(start, end)
+
+
+ classmethod
+
+
+¶Calculate trade years between start and end, close-to-close. Both start and end are aligned to an open trade day before calculation. After that, if start == end, this returns 1.
+For examples, please refer to count_day_frames
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+start | datetime.date | [description] | required
+end | datetime.date | [description] | required
+
Returns:
+
+Type | Description
+---|---
+int | years between start and end
+
omicron/models/timeframe.py
@classmethod
+def count_year_frames(cls, start: datetime.date, end: datetime.date) -> int:
+ """calc trade years between start and end date in close-to-close way
+ Both start and end will be aligned to open trade day before calculation. After
+ that, if start == end, this will returns 1.
+
+ For examples, please refer to [count_day_frames][omicron.models.timeframe.TimeFrame.count_day_frames]
+
+ Args:
+ start (datetime.date): [description]
+ end (datetime.date): [description]
+
+ Returns:
+ years between start and end
+ """
+ start = cls.date2int(start)
+ end = cls.date2int(end)
+
+ return int(ext.count_between(cls.year_frames, start, end))
+
date2int(d)
+
+
+ classmethod
+
+
+¶Convert a date to its integer representation.
+In zillionare, dates and times are usually converted to int for persistence.
+
+Examples:
+>>> TimeFrame.date2int(datetime.date(2020,5,1))
+20200501
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+d | Union[datetime.datetime, datetime.date, Arrow] | date | required
+
Returns:
+
+Type | Description
+---|---
+int | integer representation of the date, e.g. 20220211
+
omicron/models/timeframe.py
@classmethod
+def date2int(cls, d: Union[datetime.datetime, datetime.date, Arrow]) -> int:
+ """将日期转换为整数表示
+
+ 在zillionare中,如果要对时间和日期进行持久化操作,我们一般将其转换为int类型
+
+ Examples:
+ >>> TimeFrame.date2int(datetime.date(2020,5,1))
+ 20200501
+
+ Args:
+ d: date
+
+ Returns:
+ 日期的整数表示,比如20220211
+ """
+ return int(f"{d.year:04}{d.month:02}{d.day:02}")
+
day_shift(start, offset)
+
+
+ classmethod
+
+
+¶Shift the given date forward or backward by trading days.
+If offset == 0, return the trade day for start (for a non-trade day, the most recently ended trade day). If offset > 0, return the offset-th trade day after start's trade day; if offset < 0, return the offset-th trade day before it.
+
+Examples:
+>>> TimeFrame.day_frames = [20191212, 20191213, 20191216, 20191217,20191218, 20191219]
+>>> TimeFrame.day_shift(datetime.date(2019,12,13), 0)
+datetime.date(2019, 12, 13)
+
>>> TimeFrame.day_shift(datetime.date(2019, 12, 15), 0)
+datetime.date(2019, 12, 13)
+
>>> TimeFrame.day_shift(datetime.date(2019, 12, 15), 1)
+datetime.date(2019, 12, 16)
+
>>> TimeFrame.day_shift(datetime.date(2019, 12, 13), 1)
+datetime.date(2019, 12, 16)
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+start | datetime.date | the origin day | required
+offset | int | days to shift, can be negative | required
+
Returns:
+
+Type | Description
+---|---
+datetime.date | the shifted date
+
omicron/models/timeframe.py
@classmethod
+def day_shift(cls, start: datetime.date, offset: int) -> datetime.date:
+ """对指定日期进行前后移位操作
+
+ 如果 n == 0,则返回d对应的交易日(如果是非交易日,则返回刚结束的一个交易日)
+ 如果 n > 0,则返回d对应的交易日后第 n 个交易日
+ 如果 n < 0,则返回d对应的交易日前第 n 个交易日
+
+ Examples:
+ >>> TimeFrame.day_frames = [20191212, 20191213, 20191216, 20191217,20191218, 20191219]
+ >>> TimeFrame.day_shift(datetime.date(2019,12,13), 0)
+ datetime.date(2019, 12, 13)
+
+ >>> TimeFrame.day_shift(datetime.date(2019, 12, 15), 0)
+ datetime.date(2019, 12, 13)
+
+ >>> TimeFrame.day_shift(datetime.date(2019, 12, 15), 1)
+ datetime.date(2019, 12, 16)
+
+ >>> TimeFrame.day_shift(datetime.date(2019, 12, 13), 1)
+ datetime.date(2019, 12, 16)
+
+ Args:
+ start: the origin day
+ offset: days to shift, can be negative
+
+ Returns:
+ 移位后的日期
+ """
+ # accelerated from 0.12 to 0.07, per 10000 loop, type conversion time included
+ start = cls.date2int(start)
+
+ return cls.int2date(ext.shift(cls.day_frames, start, offset))
+
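+As a quick sanity check, a negative offset walks backward along the same calendar used in the doctests above (a sketch, assuming TimeFrame.day_frames is set as shown there):
+>>> TimeFrame.day_frames = [20191212, 20191213, 20191216, 20191217, 20191218, 20191219]
+>>> TimeFrame.day_shift(datetime.date(2019, 12, 16), -1)
+datetime.date(2019, 12, 13)
+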
first_min_frame(day, frame_type)
+
+
+ classmethod
+
+
+¶Get the first frame of type frame_type on the given day.
Examples:
+>>> TimeFrame.day_frames = np.array([20191227, 20191230, 20191231, 20200102, 20200103])
+>>> TimeFrame.first_min_frame('2019-12-31', FrameType.MIN1)
+datetime.datetime(2019, 12, 31, 9, 31)
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+day | Union[str, Arrow, Frame] | which day? | required
+frame_type | FrameType | which frame_type? | required
+
Returns:
+
+Type | Description
+---|---
+Union[datetime.date, datetime.datetime] | the first frame of that day
+
omicron/models/timeframe.py
@classmethod
+def first_min_frame(
+ cls, day: Union[str, Arrow, Frame], frame_type: FrameType
+) -> Union[datetime.date, datetime.datetime]:
+ """获取指定日期类型为`frame_type`的`frame`。
+
+ Examples:
+ >>> TimeFrame.day_frames = np.array([20191227, 20191230, 20191231, 20200102, 20200103])
+ >>> TimeFrame.first_min_frame('2019-12-31', FrameType.MIN1)
+ datetime.datetime(2019, 12, 31, 9, 31)
+
+ Args:
+ day: which day?
+ frame_type: which frame_type?
+
+ Returns:
+ `day`当日的第一帧
+ """
+ day = cls.date2int(arrow.get(day).date())
+
+ if frame_type == FrameType.MIN1:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=9, minute=31)
+ elif frame_type == FrameType.MIN5:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=9, minute=35)
+ elif frame_type == FrameType.MIN15:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=9, minute=45)
+ elif frame_type == FrameType.MIN30:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=10)
+ elif frame_type == FrameType.MIN60:
+ floor_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(floor_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=10, minute=30)
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} not supported")
+
floor(moment, frame_type)
+
+
+ classmethod
+
+
+¶Find the floor of moment for the given frame_type.
+For example, if moment is 10:37 and frame_type is 30 minutes, the floor is 10:00.
Examples:
+>>> # 如果moment为日期,则当成已收盘处理
+>>> TimeFrame.day_frames = np.array([20050104, 20050105, 20050106, 20050107, 20050110, 20050111])
+>>> TimeFrame.floor(datetime.date(2005, 1, 7), FrameType.DAY)
+datetime.date(2005, 1, 7)
+
>>> # moment指定的时间还未收盘,floor到上一个交易日
+>>> TimeFrame.floor(datetime.datetime(2005, 1, 7, 14, 59), FrameType.DAY)
+datetime.date(2005, 1, 6)
+
>>> TimeFrame.floor(datetime.date(2005, 1, 13), FrameType.WEEK)
+datetime.date(2005, 1, 7)
+
>>> TimeFrame.floor(datetime.date(2005,2, 27), FrameType.MONTH)
+datetime.date(2005, 1, 31)
+
>>> TimeFrame.floor(datetime.datetime(2005,1,5,14,59), FrameType.MIN30)
+datetime.datetime(2005, 1, 5, 14, 30)
+
>>> TimeFrame.floor(datetime.datetime(2005, 1, 5, 14, 59), FrameType.MIN1)
+datetime.datetime(2005, 1, 5, 14, 59)
+
>>> TimeFrame.floor(arrow.get('2005-1-5 14:59').naive, FrameType.MIN1)
+datetime.datetime(2005, 1, 5, 14, 59)
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+moment | Frame | | required
+frame_type | FrameType | | required
+
Returns:
+
+Type | Description
+---|---
+Frame | the floor of moment for the given frame_type
+
omicron/models/timeframe.py
@classmethod
+def floor(cls, moment: Frame, frame_type: FrameType) -> Frame:
+ """求`moment`在指定的`frame_type`中的下界
+
+    比如,如果`moment`为10:37,则当`frame_type`为30分钟时,对应的下界为10:00
+
+ Examples:
+ >>> # 如果moment为日期,则当成已收盘处理
+ >>> TimeFrame.day_frames = np.array([20050104, 20050105, 20050106, 20050107, 20050110, 20050111])
+ >>> TimeFrame.floor(datetime.date(2005, 1, 7), FrameType.DAY)
+ datetime.date(2005, 1, 7)
+
+ >>> # moment指定的时间还未收盘,floor到上一个交易日
+ >>> TimeFrame.floor(datetime.datetime(2005, 1, 7, 14, 59), FrameType.DAY)
+ datetime.date(2005, 1, 6)
+
+ >>> TimeFrame.floor(datetime.date(2005, 1, 13), FrameType.WEEK)
+ datetime.date(2005, 1, 7)
+
+ >>> TimeFrame.floor(datetime.date(2005,2, 27), FrameType.MONTH)
+ datetime.date(2005, 1, 31)
+
+ >>> TimeFrame.floor(datetime.datetime(2005,1,5,14,59), FrameType.MIN30)
+ datetime.datetime(2005, 1, 5, 14, 30)
+
+ >>> TimeFrame.floor(datetime.datetime(2005, 1, 5, 14, 59), FrameType.MIN1)
+ datetime.datetime(2005, 1, 5, 14, 59)
+
+ >>> TimeFrame.floor(arrow.get('2005-1-5 14:59').naive, FrameType.MIN1)
+ datetime.datetime(2005, 1, 5, 14, 59)
+
+ Args:
+ moment:
+ frame_type:
+
+ Returns:
+ `moment`在指定的`frame_type`中的下界
+ """
+ if frame_type in cls.minute_level_frames:
+ tm, day_offset = cls.minute_frames_floor(
+ cls.ticks[frame_type], moment.hour * 60 + moment.minute
+ )
+ h, m = tm // 60, tm % 60
+ if cls.day_shift(moment, 0) < moment.date() or day_offset == -1:
+ h = 15
+ m = 0
+ new_day = cls.day_shift(moment, day_offset)
+ else:
+ new_day = moment.date()
+ return datetime.datetime(new_day.year, new_day.month, new_day.day, h, m)
+
+ if type(moment) == datetime.date:
+ moment = datetime.datetime(moment.year, moment.month, moment.day, 15)
+
+ # 如果是交易日,但还未收盘
+ if (
+ cls.date2int(moment) in cls.day_frames
+ and moment.hour * 60 + moment.minute < 900
+ ):
+ moment = cls.day_shift(moment, -1)
+
+ day = cls.date2int(moment)
+ if frame_type == FrameType.DAY:
+ arr = cls.day_frames
+ elif frame_type == FrameType.WEEK:
+ arr = cls.week_frames
+ elif frame_type == FrameType.MONTH:
+ arr = cls.month_frames
+ else: # pragma: no cover
+ raise ValueError(f"frame type {frame_type} not supported.")
+
+ floored = ext.floor(arr, day)
+ return cls.int2date(floored)
+
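+The alignment also applies to day-level frames: a non-trade day floors to the last trade day. A sketch, assuming the day_frames calendar from the doctests above:
+>>> TimeFrame.floor(datetime.date(2005, 1, 8), FrameType.DAY)
+datetime.date(2005, 1, 7)
+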
frame_len(frame_type)
+
+
+ classmethod
+
+
+¶Return the frame length in minutes.
+This is not meaningful for day-level frames and above, but 240 is returned for them.
+
+Examples:
+>>> TimeFrame.frame_len(FrameType.MIN5)
+5
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+frame_type | FrameType | | required
+
Returns:
+
+Type | Description
+---|---
+int | the frame length in minutes
+
omicron/models/timeframe.py
@classmethod
+def frame_len(cls, frame_type: FrameType) -> int:
+ """返回以分钟为单位的frame长度。
+
+ 对日线以上级别没有意义,但会返回240
+
+ Examples:
+ >>> TimeFrame.frame_len(FrameType.MIN5)
+ 5
+
+ Args:
+ frame_type:
+
+ Returns:
+ 返回以分钟为单位的frame长度。
+
+ """
+
+ if frame_type == FrameType.MIN1:
+ return 1
+ elif frame_type == FrameType.MIN5:
+ return 5
+ elif frame_type == FrameType.MIN15:
+ return 15
+ elif frame_type == FrameType.MIN30:
+ return 30
+ elif frame_type == FrameType.MIN60:
+ return 60
+ else:
+ return 240
+
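+Day-level and coarser frame types all fall through to the final branch of the listing above:
+>>> TimeFrame.frame_len(FrameType.DAY)
+240
+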
get_frame_scope(frame, ft)
+
+
+ classmethod
+
+
+¶For a given date, return the first and last trade day of the week it falls in, or of the month it falls in.
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+frame | | the given date, a date object | required
+ft | FrameType | frame type; WEEK and MONTH are supported | required
+
Returns:
+
+Type | Description
+---|---
+Tuple[Frame, Frame] | first and last date of the week or month (date objects)
+
omicron/models/timeframe.py
@classmethod
+def get_frame_scope(cls, frame: Frame, ft: FrameType) -> Tuple[Frame, Frame]:
+ # todo: 函数的通用性不足,似乎应该放在具体的业务类中。如果是通用型的函数,参数不应该局限于周和月。
+ """对于给定的时间,取所在周的第一天和最后一天,所在月的第一天和最后一天
+
+ Args:
+ frame : 指定的日期,date对象
+ ft: 帧类型,支持WEEK和MONTH
+
+ Returns:
+ Tuple[Frame, Frame]: 周或者月的首末日期(date对象)
+
+ """
+ if frame is None:
+ raise ValueError("frame cannot be None")
+ if ft not in (FrameType.WEEK, FrameType.MONTH):
+ raise ValueError(f"FrameType only supports WEEK and MONTH: {ft}")
+
+ if isinstance(frame, datetime.datetime):
+ frame = frame.date()
+
+ if frame < CALENDAR_START:
+ raise ValueError(f"cannot be earlier than {CALENDAR_START}: {frame}")
+
+ # datetime.date(2021, 10, 8),这是个特殊的日期
+ if ft == FrameType.WEEK:
+ if frame < datetime.date(2005, 1, 10):
+ return datetime.date(2005, 1, 4), datetime.date(2005, 1, 7)
+
+ if not cls.is_trade_day(frame): # 非交易日的情况,直接回退一天
+ week_day = cls.day_shift(frame, 0)
+ else:
+ week_day = frame
+
+ w1 = TimeFrame.floor(week_day, FrameType.WEEK)
+ if w1 == week_day: # 本周的最后一个交易日
+ week_end = w1
+ else:
+ week_end = TimeFrame.week_shift(week_day, 1)
+
+ w0 = TimeFrame.week_shift(week_end, -1)
+ week_start = TimeFrame.day_shift(w0, 1)
+ return week_start, week_end
+
+ if ft == FrameType.MONTH:
+ if frame <= datetime.date(2005, 1, 31):
+ return datetime.date(2005, 1, 4), datetime.date(2005, 1, 31)
+
+ month_start = frame.replace(day=1)
+ if not cls.is_trade_day(month_start): # 非交易日的情况,直接加1
+ month_start = cls.day_shift(month_start, 1)
+
+ month_end = TimeFrame.month_shift(month_start, 1)
+ return month_start, month_end
+
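+A hypothetical usage sketch (illustrative values; requires an initialized calendar):
+>>> TimeFrame.get_frame_scope(datetime.date(2022, 2, 9), FrameType.WEEK)  # doctest: +SKIP
+(datetime.date(2022, 2, 7), datetime.date(2022, 2, 11))
+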
get_frames(start, end, frame_type)
+
+
+ classmethod
+
+
+¶Get all frames of type frame_type in the closed interval [start, end].
+Before calling this function, align the time frames to the frame_type boundary via floor or ceiling.
Examples:
+>>> start = arrow.get('2020-1-13 10:00').naive
+>>> end = arrow.get('2020-1-13 13:30').naive
+>>> TimeFrame.day_frames = np.array([20200109, 20200110, 20200113,20200114, 20200115, 20200116])
+>>> TimeFrame.get_frames(start, end, FrameType.MIN30)
+[202001131000, 202001131030, 202001131100, 202001131130, 202001131330]
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+start | Frame | | required
+end | Frame | | required
+frame_type | FrameType | | required
+
Returns:
+
+Type | Description
+---|---
+List[int] | frame list
+
omicron/models/timeframe.py
@classmethod
+def get_frames(cls, start: Frame, end: Frame, frame_type: FrameType) -> List[int]:
+ """取[start, end]间所有类型为frame_type的frames
+
+ 调用本函数前,请先通过`floor`或者`ceiling`将时间帧对齐到`frame_type`的边界值
+
+ Example:
+ >>> start = arrow.get('2020-1-13 10:00').naive
+ >>> end = arrow.get('2020-1-13 13:30').naive
+ >>> TimeFrame.day_frames = np.array([20200109, 20200110, 20200113,20200114, 20200115, 20200116])
+ >>> TimeFrame.get_frames(start, end, FrameType.MIN30)
+ [202001131000, 202001131030, 202001131100, 202001131130, 202001131330]
+
+ Args:
+ start:
+ end:
+ frame_type:
+
+ Returns:
+ frame list
+ """
+ n = cls.count_frames(start, end, frame_type)
+ return cls.get_frames_by_count(end, n, frame_type)
+
get_frames_by_count(end, n, frame_type)
+
+
+ classmethod
+
+
+¶Get n frames of period frame_type ending at end.
+Align end to the frame_type boundary before calling.
Examples:
+>>> end = arrow.get('2020-1-6 14:30').naive
+>>> TimeFrame.day_frames = np.array([20200102, 20200103,20200106, 20200107, 20200108, 20200109])
+>>> TimeFrame.get_frames_by_count(end, 2, FrameType.MIN30)
+[202001061400, 202001061430]
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+end | Arrow | | required
+n | int | | required
+frame_type | FrameType | | required
+
Returns:
+
+Type | Description
+---|---
+List[int] | frame list
+
omicron/models/timeframe.py
@classmethod
+def get_frames_by_count(
+ cls, end: Arrow, n: int, frame_type: FrameType
+) -> List[int]:
+ """取以end为结束点,周期为frame_type的n个frame
+
+ 调用前请将`end`对齐到`frame_type`的边界
+
+ Examples:
+ >>> end = arrow.get('2020-1-6 14:30').naive
+ >>> TimeFrame.day_frames = np.array([20200102, 20200103,20200106, 20200107, 20200108, 20200109])
+ >>> TimeFrame.get_frames_by_count(end, 2, FrameType.MIN30)
+ [202001061400, 202001061430]
+
+ Args:
+ end:
+ n:
+ frame_type:
+
+ Returns:
+ frame list
+ """
+
+ if frame_type == FrameType.DAY:
+ end = cls.date2int(end)
+ pos = np.searchsorted(cls.day_frames, end, side="right")
+ return cls.day_frames[max(0, pos - n) : pos].tolist()
+ elif frame_type == FrameType.WEEK:
+ end = cls.date2int(end)
+ pos = np.searchsorted(cls.week_frames, end, side="right")
+ return cls.week_frames[max(0, pos - n) : pos].tolist()
+ elif frame_type == FrameType.MONTH:
+ end = cls.date2int(end)
+ pos = np.searchsorted(cls.month_frames, end, side="right")
+ return cls.month_frames[max(0, pos - n) : pos].tolist()
+ elif frame_type in {
+ FrameType.MIN1,
+ FrameType.MIN5,
+ FrameType.MIN15,
+ FrameType.MIN30,
+ FrameType.MIN60,
+ }:
+ n_days = n // len(cls.ticks[frame_type]) + 2
+ ticks = cls.ticks[frame_type] * n_days
+
+ days = cls.get_frames_by_count(end, n_days, FrameType.DAY)
+ days = np.repeat(days, len(cls.ticks[frame_type]))
+
+ ticks = [
+ day.item() * 10000 + int(tm / 60) * 100 + tm % 60
+ for day, tm in zip(days, ticks)
+ ]
+
+ # list index is much faster than ext.index_sorted when the arr is small
+ pos = ticks.index(cls.time2int(end)) + 1
+
+ return ticks[max(0, pos - n) : pos]
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} not support yet")
+
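+For day-level frames the lookup is a plain searchsorted over the calendar. A sketch, consistent with the listing above:
+>>> TimeFrame.day_frames = np.array([20200102, 20200103, 20200106, 20200107])
+>>> TimeFrame.get_frames_by_count(datetime.date(2020, 1, 7), 2, FrameType.DAY)
+[20200106, 20200107]
+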
get_previous_trade_day(now)
+
+
+ classmethod
+
+
+¶Get the previous trade day.
+If the given day is a Saturday or Sunday, return Friday (a trade day); if it is a Monday, return Friday; if it is a Friday, return Thursday.
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+now | | the given date, a date object | required
+
Returns:
+
+Type | Description
+---|---
+datetime.date | the previous trade day
+
omicron/models/timeframe.py
@classmethod
+def get_previous_trade_day(cls, now: datetime.date):
+ """获取上一个交易日
+
+ 如果当天是周六或者周日,返回周五(交易日),如果当天是周一,返回周五,如果当天是周五,返回周四
+
+ Args:
+ now : 指定的日期,date对象
+
+ Returns:
+ datetime.date: 上一个交易日
+
+ """
+ if now == datetime.date(2005, 1, 4):
+ return now
+
+ if TimeFrame.is_trade_day(now):
+ pre_trade_day = TimeFrame.day_shift(now, -1)
+ else:
+ pre_trade_day = TimeFrame.day_shift(now, 0)
+ return pre_trade_day
+
get_ticks(frame_type)
+
+
+ classmethod
+
+
+¶Get the frames for month, week, day and each minute-level frame type.
+For minute-level frames, the return value contains only the time part, not the date (all represented as integers).
+
+Examples:
+>>> TimeFrame.month_frames = np.array([20050131, 20050228, 20050331])
+>>> TimeFrame.get_ticks(FrameType.MONTH)[:3]
+array([20050131, 20050228, 20050331])
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+frame_type | | [description] | required
+
Exceptions:
+
+Type | Description
+---|---
+ValueError | [description]
+
Returns:
+
+Type | Description
+---|---
+Union[List, np.array] | the frames for month, week, day and minute-level frame types
+
omicron/models/timeframe.py
@classmethod
+def get_ticks(cls, frame_type: FrameType) -> Union[List, np.array]:
+ """取月线、周线、日线及各分钟线对应的frame
+
+ 对分钟线,返回值仅包含时间,不包含日期(均为整数表示)
+
+ Examples:
+ >>> TimeFrame.month_frames = np.array([20050131, 20050228, 20050331])
+ >>> TimeFrame.get_ticks(FrameType.MONTH)[:3]
+ array([20050131, 20050228, 20050331])
+
+ Args:
+ frame_type : [description]
+
+ Raises:
+ ValueError: [description]
+
+ Returns:
+ 月线、周线、日线及各分钟线对应的frame
+ """
+ if frame_type in cls.minute_level_frames:
+ return cls.ticks[frame_type]
+
+ if frame_type == FrameType.DAY:
+ return cls.day_frames
+ elif frame_type == FrameType.WEEK:
+ return cls.week_frames
+ elif frame_type == FrameType.MONTH:
+ return cls.month_frames
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} not supported!")
+
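+For a minute-level frame type the ticks are minute offsets within the trading day. A sketch; these are the standard A-share 30-minute ticks used elsewhere on this page, shown here as an assumption:
+>>> TimeFrame.get_ticks(FrameType.MIN30)  # doctest: +SKIP
+[600, 630, 660, 690, 810, 840, 870, 900]
+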
init()
+
+
+ async
+ classmethod
+
+
+¶Initialize the calendar.
+
omicron/models/timeframe.py
@classmethod
+async def init(cls):
+ """初始化日历"""
+ await cls._load_calendar()
+
int2date(d)
+
+
+ classmethod
+
+
+¶Convert a date in integer representation to a date object.
+
+Examples:
+>>> TimeFrame.int2date(20200501)
+datetime.date(2020, 5, 1)
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+d | Union[int, str] | date in YYYYMMDD format | required
+
Returns:
+
+Type | Description
+---|---
+datetime.date | the converted date
+
omicron/models/timeframe.py
@classmethod
+def int2date(cls, d: Union[int, str]) -> datetime.date:
+ """将数字表示的日期转换成为日期格式
+
+ Examples:
+ >>> TimeFrame.int2date(20200501)
+ datetime.date(2020, 5, 1)
+
+ Args:
+ d: YYYYMMDD表示的日期
+
+ Returns:
+ 转换后的日期
+ """
+ s = str(d)
+ # it's 8 times faster than arrow.get
+ return datetime.date(int(s[:4]), int(s[4:6]), int(s[6:]))
+
int2time(tm)
+
+
+ classmethod
+
+
+¶Convert a time in integer representation to a datetime object.
+
Examples:
+>>> TimeFrame.int2time(202005011500)
+datetime.datetime(2020, 5, 1, 15, 0)
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+tm | int | time in YYYYMMDDHHmm format | required
+
Returns:
+
+Type | Description
+---|---
+datetime.datetime | the converted time
+
omicron/models/timeframe.py
@classmethod
+def int2time(cls, tm: int) -> datetime.datetime:
+ """将整数表示的时间转换为`datetime`类型表示
+
+ examples:
+ >>> TimeFrame.int2time(202005011500)
+ datetime.datetime(2020, 5, 1, 15, 0)
+
+ Args:
+ tm: time in YYYYMMDDHHmm format
+
+ Returns:
+ 转换后的时间
+ """
+ s = str(tm)
+ # its 8 times faster than arrow.get()
+ return datetime.datetime(
+ int(s[:4]), int(s[4:6]), int(s[6:8]), int(s[8:10]), int(s[10:12])
+ )
+
is_bar_closed(frame, ft)
+
+
+ classmethod
+
+
+¶Check whether the bar represented by frame has closed (ended).
+For day bars, the bar is considered closed if frame is not today, or if the current time is past the market close. For other frame types, the bar is considered closed only if frame falls exactly on a frame boundary. This rests on the assumption that we never ask, for those types, whether some future frame has closed.
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+frame | | time of the bar; must be earlier than the current time | required
+ft | FrameType | frame type of the bar | required
+
Returns:
+
+Type | Description
+---|---
+bool | whether the bar has closed
+
omicron/models/timeframe.py
@classmethod
+def is_bar_closed(cls, frame: Frame, ft: FrameType) -> bool:
+ """判断`frame`所代表的bar是否已经收盘(结束)
+
+ 如果是日线,frame不为当天,则认为已收盘;或者当前时间在收盘时间之后,也认为已收盘。
+ 如果是其它周期,则只有当frame正好在边界上,才认为是已收盘。这里有一个假设:我们不会在其它周期上,判断未来的某个frame是否已经收盘。
+
+ Args:
+ frame : bar所处的时间,必须小于当前时间
+ ft: bar所代表的帧类型
+
+ Returns:
+ bool: 是否已经收盘
+ """
+ floor = cls.floor(frame, ft)
+
+ now = arrow.now()
+ if ft == FrameType.DAY:
+ return floor < now.date() or now.hour >= 15
+ else:
+ return floor == frame
+
is_closing_call_auction_time(tm=None)
+
+
+ classmethod
+
+
+¶Check whether the time given by tm falls in the closing call auction period.
+Fixme
+The implementation is incorrect: the closing call auction period should also cover the morning close.
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+tm | | [description]. Defaults to None. | None
+
Returns:
+
+Type | Description
+---|---
+bool | whether tm falls in the closing call auction period
+
omicron/models/timeframe.py
@classmethod
+def is_closing_call_auction_time(
+ cls, tm: Union[datetime.datetime, Arrow] = None
+) -> bool:
+ """判断`tm`指定的时间是否为收盘集合竞价时间
+
+ Fixme:
+ 此处实现有误,收盘集合竞价时间应该还包含上午收盘时间
+
+ Args:
+ tm : [description]. Defaults to None.
+
+ Returns:
+ bool
+ """
+ tm = tm or cls.now()
+
+ if not cls.is_trade_day(tm):
+ return False
+
+ minutes = tm.hour * 60 + tm.minute
+ return 15 * 60 - 3 <= minutes < 15 * 60
+
is_open_time(tm=None)
+
+
+ classmethod
+
+
+¶Check whether the time given by tm falls within trading hours.
+Trading hours here means the open market session excluding the call auction periods.
+
+Examples:
+>>> TimeFrame.day_frames = np.array([20200102, 20200103, 20200106, 20200107, 20200108])
+>>> TimeFrame.is_open_time(arrow.get('2020-1-1 14:59').naive)
+False
+>>> TimeFrame.is_open_time(arrow.get('2020-1-3 14:59').naive)
+True
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+tm | | [description]. Defaults to None. | None
+
Returns:
+
+Type | Description
+---|---
+bool | whether tm falls within trading hours
+
omicron/models/timeframe.py
@classmethod
+def is_open_time(cls, tm: Union[datetime.datetime, Arrow] = None) -> bool:
+ """判断`tm`指定的时间是否处在交易时间段。
+
+ 交易时间段是指集合竞价时间段之外的开盘时间
+
+ Examples:
+ >>> TimeFrame.day_frames = np.array([20200102, 20200103, 20200106, 20200107, 20200108])
+ >>> TimeFrame.is_open_time(arrow.get('2020-1-1 14:59').naive)
+ False
+ >>> TimeFrame.is_open_time(arrow.get('2020-1-3 14:59').naive)
+ True
+
+ Args:
+ tm : [description]. Defaults to None.
+
+ Returns:
+ bool
+ """
+ tm = tm or arrow.now()
+
+ if not cls.is_trade_day(tm):
+ return False
+
+ tick = tm.hour * 60 + tm.minute
+ return tick in cls.ticks[FrameType.MIN1]
+
is_opening_call_auction_time(tm=None)
+
+
+ classmethod
+
+
+¶Check whether the time given by tm falls in the opening call auction period.
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+tm | | [description]. Defaults to None. | None
+
Returns:
+
+Type | Description
+---|---
+bool | whether tm falls in the opening call auction period
+
omicron/models/timeframe.py
@classmethod
+def is_opening_call_auction_time(
+ cls, tm: Union[Arrow, datetime.datetime] = None
+) -> bool:
+ """判断`tm`指定的时间是否为开盘集合竞价时间
+
+ Args:
+ tm : [description]. Defaults to None.
+
+ Returns:
+ bool
+ """
+ if tm is None:
+ tm = cls.now()
+
+ if not cls.is_trade_day(tm):
+ return False
+
+ minutes = tm.hour * 60 + tm.minute
+ return 9 * 60 + 15 < minutes <= 9 * 60 + 25
+
is_trade_day(dt)
+
+
+ classmethod
+
+
+¶Check whether dt is a trade day.
Examples:
+>>> TimeFrame.is_trade_day(arrow.get('2020-1-1'))
+False
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+dt | | | required
+
Returns:
+
+Type | Description
+---|---
+bool | whether dt is a trade day
+
omicron/models/timeframe.py
@classmethod
+def is_trade_day(cls, dt: Union[datetime.date, datetime.datetime, Arrow]) -> bool:
+ """判断`dt`是否为交易日
+
+ Examples:
+ >>> TimeFrame.is_trade_day(arrow.get('2020-1-1'))
+ False
+
+ Args:
+ dt :
+
+ Returns:
+ bool
+ """
+ return cls.date2int(dt) in cls.day_frames
+
last_min_frame(day, frame_type)
+
+
+ classmethod
+
+
+¶Get the closing frame of type frame_type on the given day.
Examples:
+>>> TimeFrame.last_min_frame(arrow.get('2020-1-5').date(), FrameType.MIN30)
+datetime.datetime(2020, 1, 3, 15, 0)
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+day | Union[str, Arrow, datetime.date] | | required
+frame_type | FrameType | | required
+
Returns:
+
+Type | Description
+---|---
+Union[datetime.date, datetime.datetime] | the closing frame of type frame_type on day
+
omicron/models/timeframe.py
@classmethod
+def last_min_frame(
+ cls, day: Union[str, Arrow, datetime.date], frame_type: FrameType
+) -> Union[datetime.date, datetime.datetime]:
+ """获取`day`日周期为`frame_type`的结束frame。
+
+ Example:
+ >>> TimeFrame.last_min_frame(arrow.get('2020-1-5').date(), FrameType.MIN30)
+ datetime.datetime(2020, 1, 3, 15, 0)
+
+ Args:
+ day:
+ frame_type:
+
+ Returns:
+ `day`日周期为`frame_type`的结束frame
+ """
+ if isinstance(day, str):
+ day = cls.date2int(arrow.get(day).date())
+ elif isinstance(day, arrow.Arrow) or isinstance(day, datetime.datetime):
+ day = cls.date2int(day.date())
+ elif isinstance(day, datetime.date):
+ day = cls.date2int(day)
+ else:
+ raise TypeError(f"{type(day)} is not supported.")
+
+ if frame_type in cls.minute_level_frames:
+ last_close_day = cls.day_frames[cls.day_frames <= day][-1]
+ day = cls.int2date(last_close_day)
+ return datetime.datetime(day.year, day.month, day.day, hour=15, minute=0)
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} not supported")
+
minute_frames_floor(ticks, moment)
+
+
+ classmethod
+
+
+¶For minute-level frames, return the frame floored to the tick grid, together with a date carry: -1 if the alignment falls back to the previous trade day, otherwise 0.
+
+Examples:
+>>> ticks = [600, 630, 660, 690, 810, 840, 870, 900]
+>>> TimeFrame.minute_frames_floor(ticks, 545)
+(900, -1)
+>>> TimeFrame.minute_frames_floor(ticks, 600)
+(600, 0)
+>>> TimeFrame.minute_frames_floor(ticks, 605)
+(600, 0)
+>>> TimeFrame.minute_frames_floor(ticks, 899)
+(870, 0)
+>>> TimeFrame.minute_frames_floor(ticks, 900)
+(900, 0)
+>>> TimeFrame.minute_frames_floor(ticks, 905)
+(900, 0)
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+ticks | np.array or list | frame ticks | required
+moment | int | minutes as an integer, e.g. 900 means 15:00 | required
+
Returns:
+
+Type | Description
+---|---
+Tuple[int, int] | a tuple: the new moment, then the carry
+
omicron/models/timeframe.py
@classmethod
+def minute_frames_floor(cls, ticks, moment) -> Tuple[int, int]:
+ """
+ 对于分钟级的frame,返回它们与frame刻度向下对齐后的frame及日期进位。如果需要对齐到上一个交易
+ 日,则进位为-1,否则为0.
+
+ Examples:
+ >>> ticks = [600, 630, 660, 690, 810, 840, 870, 900]
+ >>> TimeFrame.minute_frames_floor(ticks, 545)
+ (900, -1)
+ >>> TimeFrame.minute_frames_floor(ticks, 600)
+ (600, 0)
+ >>> TimeFrame.minute_frames_floor(ticks, 605)
+ (600, 0)
+ >>> TimeFrame.minute_frames_floor(ticks, 899)
+ (870, 0)
+ >>> TimeFrame.minute_frames_floor(ticks, 900)
+ (900, 0)
+ >>> TimeFrame.minute_frames_floor(ticks, 905)
+ (900, 0)
+
+ Args:
+ ticks (np.array or list): frames刻度
+ moment (int): 整数表示的分钟数,比如900表示15:00
+
+ Returns:
+ tuple, the first is the new moment, the second is carry-on
+ """
+ if moment < ticks[0]:
+ return ticks[-1], -1
+ # ’right' 相当于 ticks <= m
+ index = np.searchsorted(ticks, moment, side="right")
+ return ticks[index - 1], 0
+
month_shift(start, offset)
+
+
+ classmethod
+
+
+¶Shift the month frame containing start by offset.
+This function first aligns start, then shifts it.
+
Examples:
+>>> TimeFrame.month_frames = np.array([20150130, 20150227, 20150331, 20150430])
+>>> TimeFrame.month_shift(arrow.get('2015-2-26').date(), 0)
+datetime.date(2015, 1, 30)
+
>>> TimeFrame.month_shift(arrow.get('2015-2-27').date(), 0)
+datetime.date(2015, 2, 27)
+
>>> TimeFrame.month_shift(arrow.get('2015-3-1').date(), 0)
+datetime.date(2015, 2, 27)
+
>>> TimeFrame.month_shift(arrow.get('2015-3-1').date(), 1)
+datetime.date(2015, 3, 31)
+
Returns:
+
+Type | Description
+---|---
+datetime.date | the shifted date
+
omicron/models/timeframe.py
@classmethod
+def month_shift(cls, start: datetime.date, offset: int) -> datetime.date:
+ """求`start`所在的月移位后的frame
+
+ 本函数首先将`start`对齐,然后进行移位。
+ Examples:
+ >>> TimeFrame.month_frames = np.array([20150130, 20150227, 20150331, 20150430])
+ >>> TimeFrame.month_shift(arrow.get('2015-2-26').date(), 0)
+ datetime.date(2015, 1, 30)
+
+ >>> TimeFrame.month_shift(arrow.get('2015-2-27').date(), 0)
+ datetime.date(2015, 2, 27)
+
+ >>> TimeFrame.month_shift(arrow.get('2015-3-1').date(), 0)
+ datetime.date(2015, 2, 27)
+
+ >>> TimeFrame.month_shift(arrow.get('2015-3-1').date(), 1)
+ datetime.date(2015, 3, 31)
+
+ Returns:
+ 移位后的日期
+ """
+ start = cls.date2int(start)
+ return cls.int2date(ext.shift(cls.month_frames, start, offset))
+
replace_date(dtm, dt)
+
+
+ classmethod
+
+
+¶Replace the date part of dtm with the date given by dt.
Examples:
+>>> TimeFrame.replace_date(arrow.get('2020-1-1 13:49').datetime, datetime.date(2019, 1,1))
+datetime.datetime(2019, 1, 1, 13, 49)
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+dtm | datetime.datetime | [description] | required
+dt | datetime.date | [description] | required
+
Returns:
+
+Type | Description
+---|---
+datetime.datetime | the transformed time
+
omicron/models/timeframe.py
@classmethod
+def replace_date(
+ cls, dtm: datetime.datetime, dt: datetime.date
+) -> datetime.datetime:
+ """将`dtm`变量的日期更换为`dt`指定的日期
+
+ Example:
+ >>> TimeFrame.replace_date(arrow.get('2020-1-1 13:49').datetime, datetime.date(2019, 1,1))
+ datetime.datetime(2019, 1, 1, 13, 49)
+
+ Args:
+ dtm (datetime.datetime): [description]
+ dt (datetime.date): [description]
+
+ Returns:
+ 变换后的时间
+ """
+ return datetime.datetime(
+ dt.year, dt.month, dt.day, dtm.hour, dtm.minute, dtm.second, dtm.microsecond
+ )
+
resample_frames(trade_days, frame_type)
+
+
+ classmethod
+
+
+¶Resample the trading calendar fetched from the quotes server into week and month frames (the implementation below also covers quarter and year frames).
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+trade_days | Iterable | [description] | required
+frame_type | FrameType | [description] | required
+
Returns:
+
+Type | Description
+---|---
+List[int] | the resampled list of dates, represented as integers
+
omicron/models/timeframe.py
@classmethod
+def resample_frames(
+ cls, trade_days: Iterable[datetime.date], frame_type: FrameType
+) -> List[int]:
+ """将从行情服务器获取的交易日历重采样,生成周帧和月线帧
+
+ Args:
+ trade_days (Iterable): [description]
+ frame_type (FrameType): [description]
+
+ Returns:
+ List[int]: 重采样后的日期列表,日期用整数表示
+ """
+ if frame_type == FrameType.WEEK:
+ weeks = []
+ last = trade_days[0]
+ for cur in trade_days:
+ if cur.weekday() < last.weekday() or (cur - last).days >= 7:
+ weeks.append(last)
+ last = cur
+
+ if weeks[-1] < last:
+ weeks.append(last)
+
+ return weeks
+ elif frame_type == FrameType.MONTH:
+ months = []
+ last = trade_days[0]
+ for cur in trade_days:
+ if cur.day < last.day:
+ months.append(last)
+ last = cur
+ months.append(last)
+
+ return months
+ elif frame_type == FrameType.QUARTER:
+ quarters = []
+ last = trade_days[0]
+ for cur in trade_days:
+ if last.month % 3 == 0:
+ if cur.month > last.month or cur.year > last.year:
+ quarters.append(last)
+ last = cur
+ quarters.append(last)
+
+ return quarters
+ elif frame_type == FrameType.YEAR:
+ years = []
+ last = trade_days[0]
+ for cur in trade_days:
+ if cur.year > last.year:
+ years.append(last)
+ last = cur
+ years.append(last)
+
+ return years
+ else: # pragma: no cover
+ raise ValueError(f"Unsupported FrameType: {frame_type}")
+
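+A small sketch of the WEEK branch with hypothetical input: a week boundary is detected when the weekday decreases or a gap of seven or more days appears, and the last day of each completed week is kept.
+>>> trade_days = [datetime.date(2020, 1, 2), datetime.date(2020, 1, 3),
+...               datetime.date(2020, 1, 6), datetime.date(2020, 1, 10)]
+>>> TimeFrame.resample_frames(trade_days, FrameType.WEEK)
+[datetime.date(2020, 1, 3), datetime.date(2020, 1, 10)]
+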
service_degrade()
+
+
+ classmethod
+
+
+¶When no calendar exists in the cache, fall back to the calendar bundled with this release of omicron.
+Note: the bundled calendar is very likely not up to date and may contain errors. For example, at release time the calendar may be accurate only up to 2021-12-31, with anything after that date potentially wrong. Call this method only in special situations (such as testing) to obtain a degraded service.
+
omicron/models/timeframe.py
@classmethod
+def service_degrade(cls):
+ """当cache中不存在日历时,启用随omicron版本一起发行时自带的日历。
+
+ 注意:随omicron版本一起发行时自带的日历很可能不是最新的,并且可能包含错误。比如,存在这样的情况,在本版本的omicron发行时,日历更新到了2021年12月31日,在这之前的日历都是准确的,但在此之后的日历,则有可能出现错误。因此,只应该在特殊的情况下(比如测试)调用此方法,以获得一个降级的服务。
+ """
+ _dir = os.path.dirname(__file__)
+ file = os.path.join(_dir, "..", "config", "calendar.json")
+ with open(file, "r") as f:
+ data = json.load(f)
+ for k, v in data.items():
+ setattr(cls, k, np.array(v))
+
shift(moment, n, frame_type)
+
+
+ classmethod
+
+
+¶Shift the given moment by n frame_type positions.
+A negative n moves backward; a positive n moves forward. If n is zero, move to the nearest frame that has already ended.
+If moment is not aligned to frame_type, it is aligned first.
+See also:
+
+- day_shift
+- week_shift
+- month_shift
+
+Examples:
+>>> TimeFrame.shift(datetime.date(2020, 1, 3), 1, FrameType.DAY)
+datetime.date(2020, 1, 6)
+
>>> TimeFrame.shift(datetime.datetime(2020, 1, 6, 11), 1, FrameType.MIN30)
+datetime.datetime(2020, 1, 6, 11, 30)
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+moment | Union[Arrow, datetime.date, datetime.datetime] | | required
+n | int | | required
+frame_type | FrameType | | required
+
Returns:
+
+Type | Description
+---|---
+Union[datetime.date, datetime.datetime] | the shifted frame
+
omicron/models/timeframe.py
@classmethod
+def shift(
+ cls,
+ moment: Union[Arrow, datetime.date, datetime.datetime],
+ n: int,
+ frame_type: FrameType,
+) -> Union[datetime.date, datetime.datetime]:
+ """将指定的moment移动N个`frame_type`位置。
+
+ 当N为负数时,意味着向前移动;当N为正数时,意味着向后移动。如果n为零,意味着移动到最接近
+ 的一个已结束的frame。
+
+ 如果moment没有对齐到frame_type对应的时间,将首先进行对齐。
+
+ See also:
+
+ - [day_shift][omicron.models.timeframe.TimeFrame.day_shift]
+ - [week_shift][omicron.models.timeframe.TimeFrame.week_shift]
+ - [month_shift][omicron.models.timeframe.TimeFrame.month_shift]
+
+ Examples:
+ >>> TimeFrame.shift(datetime.date(2020, 1, 3), 1, FrameType.DAY)
+ datetime.date(2020, 1, 6)
+
+ >>> TimeFrame.shift(datetime.datetime(2020, 1, 6, 11), 1, FrameType.MIN30)
+ datetime.datetime(2020, 1, 6, 11, 30)
+
+
+ Args:
+ moment:
+ n:
+ frame_type:
+
+ Returns:
+ 移位后的Frame
+ """
+ if frame_type == FrameType.DAY:
+ return cls.day_shift(moment, n)
+
+ elif frame_type == FrameType.WEEK:
+ return cls.week_shift(moment, n)
+ elif frame_type == FrameType.MONTH:
+ return cls.month_shift(moment, n)
+ elif frame_type in [
+ FrameType.MIN1,
+ FrameType.MIN5,
+ FrameType.MIN15,
+ FrameType.MIN30,
+ FrameType.MIN60,
+ ]:
+ tm = moment.hour * 60 + moment.minute
+
+ new_tick_pos = cls.ticks[frame_type].index(tm) + n
+ days = new_tick_pos // len(cls.ticks[frame_type])
+ min_part = new_tick_pos % len(cls.ticks[frame_type])
+
+ date_part = cls.day_shift(moment.date(), days)
+ minutes = cls.ticks[frame_type][min_part]
+ h, m = minutes // 60, minutes % 60
+ return datetime.datetime(
+ date_part.year,
+ date_part.month,
+ date_part.day,
+ h,
+ m,
+ tzinfo=moment.tzinfo,
+ )
+ else: # pragma: no cover
+ raise ValueError(f"{frame_type} is not supported.")
+
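+Minute-level shifts roll over to the next trade day when the tick index passes the end of the session. A sketch, assuming the standard 30-minute ticks and a calendar containing the two days below:
+>>> TimeFrame.day_frames = np.array([20200103, 20200106])
+>>> TimeFrame.shift(datetime.datetime(2020, 1, 3, 15), 1, FrameType.MIN30)
+datetime.datetime(2020, 1, 6, 10, 0)
+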
time2int(tm)
+
+
+ classmethod
+
+
+¶Convert a time to its integer representation.
+tm may be an Arrow, a datetime.datetime, or any other type, as long as it has year, month, etc. attributes.
+
+Examples:
+>>> TimeFrame.time2int(datetime.datetime(2020, 5, 1, 15))
+202005011500
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+tm | Union[datetime.datetime, Arrow] | | required
+
Returns:
+
+Type | Description
+---|---
+int | the converted integer, e.g. 202005011500
+
omicron/models/timeframe.py
@classmethod
+def time2int(cls, tm: Union[datetime.datetime, Arrow]) -> int:
+ """将时间类型转换为整数类型
+
+ tm可以是Arrow类型,也可以是datetime.datetime或者任何其它类型,只要它有year,month...等
+ 属性
+ Examples:
+ >>> TimeFrame.time2int(datetime.datetime(2020, 5, 1, 15))
+ 202005011500
+
+ Args:
+ tm:
+
+ Returns:
+        转换后的整数,比如202005011500
+ """
+ return int(f"{tm.year:04}{tm.month:02}{tm.day:02}{tm.hour:02}{tm.minute:02}")
+
week_shift(start, offset)
+
+
+ classmethod
+
+
+¶Shift the given date forward or backward by week frames.
+See omicron.models.timeframe.TimeFrame.day_shift.
+
+Examples:
+>>> TimeFrame.week_frames = np.array([20200103, 20200110, 20200117, 20200123,20200207, 20200214])
+>>> moment = arrow.get('2020-1-21').date()
+>>> TimeFrame.week_shift(moment, 1)
+datetime.date(2020, 1, 23)
+
>>> TimeFrame.week_shift(moment, 0)
+datetime.date(2020, 1, 17)
+
>>> TimeFrame.week_shift(moment, -1)
+datetime.date(2020, 1, 10)
+
Returns:
+
+Type | Description
+---|---
+datetime.date | the shifted date
+
omicron/models/timeframe.py
@classmethod
+def week_shift(cls, start: datetime.date, offset: int) -> datetime.date:
+ """对指定日期按周线帧进行前后移位操作
+
+ 参考 [omicron.models.timeframe.TimeFrame.day_shift][]
+ Examples:
+ >>> TimeFrame.week_frames = np.array([20200103, 20200110, 20200117, 20200123,20200207, 20200214])
+ >>> moment = arrow.get('2020-1-21').date()
+ >>> TimeFrame.week_shift(moment, 1)
+ datetime.date(2020, 1, 23)
+
+ >>> TimeFrame.week_shift(moment, 0)
+ datetime.date(2020, 1, 17)
+
+ >>> TimeFrame.week_shift(moment, -1)
+ datetime.date(2020, 1, 10)
+
+ Returns:
+ 移位后的日期
+ """
+ start = cls.date2int(start)
+ return cls.int2date(ext.shift(cls.week_frames, start, offset))
+
+Building on apscheduler.triggers, this module provides FrameTrigger and IntervalTrigger, which fire only on trade days (or on trade days plus a delay).
+
+FrameTrigger (BaseTrigger)
+
+
+
+
+¶A cron-like trigger that fires on each valid Frame
+ +omicron/core/triggers.py
class FrameTrigger(BaseTrigger):
+ """
+ A cron like trigger fires on each valid Frame
+ """
+
+ def __init__(self, frame_type: Union[str, FrameType], jitter: str = None):
+ """构造函数
+
+ jitter的格式用正则式表达为`r"([-]?)(\\d+)([mshd])"`,其中第一组为符号,'-'表示提前;
+ 第二组为数字,第三组为单位,可以为`m`(分钟), `s`(秒), `h`(小时),`d`(天)。
+
+ 下面的示例构造了一个只在交易日,每30分钟触发一次,每次提前15秒触的trigger。即它的触发时
+ 间是每个交易日的09:29:45, 09:59:45, ...
+
+ Examples:
+ >>> FrameTrigger(FrameType.MIN30, '-15s') # doctest: +ELLIPSIS
+ <omicron.core.triggers.FrameTrigger object at 0x...>
+
+ Args:
+ frame_type:
+ jitter: 单位秒。其中offset必须在一个FrameType的长度以内
+ """
+ self.frame_type = FrameType(frame_type)
+ if jitter is None:
+ _jitter = 0
+ else:
+ matched = re.match(r"([-]?)(\d+)([mshd])", jitter)
+ if matched is None: # pragma: no cover
+ raise ValueError(
+ "malformed. jitter should be [-](number)(unit), "
+ "for example, -30m, or 30s"
+ )
+ sign, num, unit = matched.groups()
+ num = int(num)
+ if unit.lower() == "m":
+ _jitter = 60 * num
+ elif unit.lower() == "s":
+ _jitter = num
+ elif unit.lower() == "h":
+ _jitter = 3600 * num
+ elif unit.lower() == "d":
+ _jitter = 3600 * 24 * num
+ else: # pragma: no cover
+ raise ValueError("bad time unit. only s,h,m,d is acceptable")
+
+ if sign == "-":
+ _jitter = -_jitter
+
+ self.jitter = datetime.timedelta(seconds=_jitter)
+ if (
+ frame_type == FrameType.MIN1
+ and abs(_jitter) >= 60
+ or frame_type == FrameType.MIN5
+ and abs(_jitter) >= 300
+ or frame_type == FrameType.MIN15
+ and abs(_jitter) >= 900
+ or frame_type == FrameType.MIN30
+ and abs(_jitter) >= 1800
+ or frame_type == FrameType.MIN60
+ and abs(_jitter) >= 3600
+ or frame_type == FrameType.DAY
+ and abs(_jitter) >= 24 * 3600
+ # it's still not allowed if offset > week, month, etc. Would anybody
+ # really specify an offset longer than that?
+ ):
+ raise ValueError("offset must be less than frame length")
+
+ def __str__(self):
+ return f"{self.__class__.__name__}:{self.frame_type.value}:{self.jitter}"
+
+ def get_next_fire_time(
+ self,
+ previous_fire_time: Union[datetime.date, datetime.datetime],
+ now: Union[datetime.date, datetime.datetime],
+ ):
+ """"""
+ ft = self.frame_type
+
+ # `now` is timezone aware, while ceiling isn't
+ now = now.replace(tzinfo=None)
+ next_tick = now
+ next_frame = TimeFrame.ceiling(now, ft)
+ while next_tick <= now:
+ if ft in TimeFrame.day_level_frames:
+ next_tick = TimeFrame.combine_time(next_frame, 15) + self.jitter
+ else:
+ next_tick = next_frame + self.jitter
+
+ if next_tick > now:
+ tz = tzlocal.get_localzone()
+ return next_tick.astimezone(tz)
+ else:
+ next_frame = TimeFrame.shift(next_frame, 1, ft)
+
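+A hypothetical scheduling sketch: the trigger instance is handed straight to APScheduler. The scheduler type and the coroutine name collect_bars are assumptions, not part of omicron.
+>>> from apscheduler.schedulers.asyncio import AsyncIOScheduler
+>>> scheduler = AsyncIOScheduler()
+>>> trigger = FrameTrigger(FrameType.MIN30, "-15s")
+>>> scheduler.add_job(collect_bars, trigger)  # doctest: +SKIP
+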
__init__(self, frame_type, jitter=None)
+
+
+ special
+
+
+¶Constructor.
+The jitter format, as a regular expression, is r"([-]?)(\d+)([mshd])": the first group is the sign ('-' means fire early), the second is a number, and the third is the unit, one of m (minutes), s (seconds), h (hours) or d (days).
+The example below builds a trigger that fires only on trade days, every 30 minutes, 15 seconds early each time; that is, it fires at 09:29:45, 09:59:45, ... on every trade day.
+ +Examples:
+>>> FrameTrigger(FrameType.MIN30, '-15s')
+<omicron.core.triggers.FrameTrigger object at 0x...>
+
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+frame_type | Union[str, coretypes.types.FrameType] | | required
+jitter | str | in seconds; the offset must be less than one frame_type length | None
+
omicron/core/triggers.py
def __init__(self, frame_type: Union[str, FrameType], jitter: str = None):
+ """构造函数
+
+ jitter的格式用正则式表达为`r"([-]?)(\\d+)([mshd])"`,其中第一组为符号,'-'表示提前;
+ 第二组为数字,第三组为单位,可以为`m`(分钟), `s`(秒), `h`(小时),`d`(天)。
+
+ 下面的示例构造了一个只在交易日,每30分钟触发一次,每次提前15秒触的trigger。即它的触发时
+ 间是每个交易日的09:29:45, 09:59:45, ...
+
+ Examples:
+ >>> FrameTrigger(FrameType.MIN30, '-15s') # doctest: +ELLIPSIS
+ <omicron.core.triggers.FrameTrigger object at 0x...>
+
+ Args:
+ frame_type:
+ jitter: 单位秒。其中offset必须在一个FrameType的长度以内
+ """
+ self.frame_type = FrameType(frame_type)
+ if jitter is None:
+ _jitter = 0
+ else:
+ matched = re.match(r"([-]?)(\d+)([mshd])", jitter)
+ if matched is None: # pragma: no cover
+ raise ValueError(
+ "malformed. jitter should be [-](number)(unit), "
+ "for example, -30m, or 30s"
+ )
+ sign, num, unit = matched.groups()
+ num = int(num)
+ if unit.lower() == "m":
+ _jitter = 60 * num
+ elif unit.lower() == "s":
+ _jitter = num
+ elif unit.lower() == "h":
+ _jitter = 3600 * num
+ elif unit.lower() == "d":
+ _jitter = 3600 * 24 * num
+ else: # pragma: no cover
+ raise ValueError("bad time unit. only s,h,m,d is acceptable")
+
+ if sign == "-":
+ _jitter = -_jitter
+
+ self.jitter = datetime.timedelta(seconds=_jitter)
+ if (
+ frame_type == FrameType.MIN1
+ and abs(_jitter) >= 60
+ or frame_type == FrameType.MIN5
+ and abs(_jitter) >= 300
+ or frame_type == FrameType.MIN15
+ and abs(_jitter) >= 900
+ or frame_type == FrameType.MIN30
+ and abs(_jitter) >= 1800
+ or frame_type == FrameType.MIN60
+ and abs(_jitter) >= 3600
+ or frame_type == FrameType.DAY
+ and abs(_jitter) >= 24 * 3600
+ # it's still not allowed if offset > week, month, etc. Would anybody
+ # really specify an offset longer than that?
+ ):
+ raise ValueError("offset must be less than frame length")
+
+TradeTimeIntervalTrigger (BaseTrigger)
+
+
+
+
+¶A fixed-interval trigger that fires only during trading hours
+ +omicron/core/triggers.py
class TradeTimeIntervalTrigger(BaseTrigger):
+ """只在交易时间触发的固定间隔的trigger"""
+
+ def __init__(self, interval: str):
+ """构造函数
+
+ interval的格式用正则表达式表示为 `r"(\\d+)([mshd])"` 。其中第一组为数字,第二组为单位。有效的
+ `interval`如 1 ,表示每1小时触发一次,则该触发器将在交易日的10:30, 11:30, 14:00和
+ 15:00各触发一次
+
+ Args:
+ interval : [description]
+
+ Raises:
+ ValueError: [description]
+ """
+ matched = re.match(r"(\d+)([mshd])", interval)
+ if matched is None:
+ raise ValueError(f"malform interval {interval}")
+
+ interval, unit = matched.groups()
+ interval = int(interval)
+ unit = unit.lower()
+ if unit == "s":
+ self.interval = datetime.timedelta(seconds=interval)
+ elif unit == "m":
+ self.interval = datetime.timedelta(minutes=interval)
+ elif unit == "h":
+ self.interval = datetime.timedelta(hours=interval)
+ elif unit == "d":
+ self.interval = datetime.timedelta(days=interval)
+ else:
+ self.interval = datetime.timedelta(seconds=interval)
+
+ def __str__(self):
+ return f"{self.__class__.__name__}:{self.interval.seconds}"
+
+ def get_next_fire_time(
+ self,
+ previous_fire_time: Optional[datetime.datetime],
+ now: Optional[datetime.datetime],
+ ):
+ """"""
+ if previous_fire_time is not None:
+ fire_time = previous_fire_time + self.interval
+ else:
+ fire_time = now
+
+ if TimeFrame.date2int(fire_time.date()) not in TimeFrame.day_frames:
+ ft = TimeFrame.day_shift(now, 1)
+ fire_time = datetime.datetime(
+ ft.year, ft.month, ft.day, 9, 30, tzinfo=fire_time.tzinfo
+ )
+ return fire_time
+
+ minutes = fire_time.hour * 60 + fire_time.minute
+
+ if minutes < 570:
+ fire_time = fire_time.replace(hour=9, minute=30, second=0, microsecond=0)
+ elif 690 < minutes < 780:
+ fire_time = fire_time.replace(hour=13, minute=0, second=0, microsecond=0)
+ elif minutes > 900:
+ ft = TimeFrame.day_shift(fire_time, 1)
+ fire_time = datetime.datetime(
+ ft.year, ft.month, ft.day, 9, 30, tzinfo=fire_time.tzinfo
+ )
+
+ return fire_time
+
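+A hypothetical usage sketch, under the same scheduler assumption as the FrameTrigger example above; poll_quotes is an assumed coroutine name.
+>>> scheduler.add_job(poll_quotes, TradeTimeIntervalTrigger("10s"))  # doctest: +SKIP
+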
__init__(self, interval)
+
+
+ special
+
+
+¶Constructor.
+The interval format, as a regular expression, is r"(\d+)([mshd])": the first group is a number, the second is the unit. For example, a valid interval of "1h" fires once per hour, so the trigger fires at 10:30, 11:30, 14:00 and 15:00 on each trade day.
Parameters:
+
+Name | Type | Description | Default
+---|---|---|---
+interval | | [description] | required
+
Exceptions:
+
+Type | Description
+---|---
+ValueError | [description]
+
omicron/core/triggers.py
def __init__(self, interval: str):
+ """构造函数
+
+ interval的格式用正则表达式表示为 `r"(\\d+)([mshd])"` 。其中第一组为数字,第二组为单位。有效的
+ `interval`如 1 ,表示每1小时触发一次,则该触发器将在交易日的10:30, 11:30, 14:00和
+ 15:00各触发一次
+
+ Args:
+ interval : [description]
+
+ Raises:
+ ValueError: [description]
+ """
+ matched = re.match(r"(\d+)([mshd])", interval)
+ if matched is None:
+ raise ValueError(f"malform interval {interval}")
+
+ interval, unit = matched.groups()
+ interval = int(interval)
+ unit = unit.lower()
+ if unit == "s":
+ self.interval = datetime.timedelta(seconds=interval)
+ elif unit == "m":
+ self.interval = datetime.timedelta(minutes=interval)
+ elif unit == "h":
+ self.interval = datetime.timedelta(hours=interval)
+ elif unit == "d":
+ self.interval = datetime.timedelta(days=interval)
+ else:
+ self.interval = datetime.timedelta(seconds=interval)
+