diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d07ba1f881..e196c124bd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,7 +52,7 @@ jobs: - name: Test data downloads and examples run: | - python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn # cd examples # estimator -c estimator/estimator_config.yaml # jupyter nbconvert --execute estimator/analyze_from_estimator.ipynb --to html \ No newline at end of file diff --git a/README.md b/README.md index 31eb7f532a..5ff9b624b1 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,7 @@ Also, users can install ``Qlib`` by the source code according to the following s ## Data Preparation Load and prepare data by running the following code: ```bash - python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn ``` This dataset is created by public data collected by [crawler scripts](scripts/data_collector/), which have been released in diff --git a/docs/component/data.rst b/docs/component/data.rst index 6b813b39e6..ba4cc0053c 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -34,7 +34,7 @@ Qlib Format Dataset .. code-block:: bash - python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn After running the above command, users can find china-stock data in Qlib format in the ``~/.qlib/csv_data/cn_data`` directory. @@ -59,7 +59,7 @@ Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv .. code-block:: bash - python scripts/dump_bin.py dump --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor + python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor After conversion, users can find their Qlib format data in the directory `~/.qlib/qlib_data/my_data`. diff --git a/docs/introduction/quick.rst b/docs/introduction/quick.rst index df4b84062f..9fff8cb3fe 100644 --- a/docs/introduction/quick.rst +++ b/docs/introduction/quick.rst @@ -40,7 +40,7 @@ Load and prepare data by running the following code: .. code-block:: - python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn This dataset is created by public data collected by crawler scripts in ``scripts/data_collector/``, which have been released in the same repository. Users could create the same dataset with it. diff --git a/docs/start/initialization.rst b/docs/start/initialization.rst index bcb09925ef..b975a18c5a 100644 --- a/docs/start/initialization.rst +++ b/docs/start/initialization.rst @@ -14,7 +14,7 @@ Please follow the steps below to initialize ``Qlib``. - Download and prepare the Data: execute the following command to download stock data. Please pay `attention` that the data is collected from `Yahoo Finance `_ and the data might not be perfect. We recommend users to prepare their own data if they have high-quality datasets. Please refer to `Data <../component/data.html#converting-csv-format-into-qlib-format>` for more information about customized dataset. .. 
code-block:: bash - python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn Please refer to `Data Preparation <../component/data.html#data-preparation>`_ for more information about `get_data.py`, diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 9812864af1..d32b251ded 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -56,7 +56,24 @@ def __init__( end_time=None, data_loader: Tuple[dict, str, DataLoader] = None, init_data=True, + fetch_orig=True, ): + """ + Parameters + ---------- + instruments : + The stock list to retrive + start_time : + start_time of the original data + end_time : + end_time of the original data + data_loader : Tuple[dict, str, DataLoader] + data loader to load the data + init_data : + intialize the original data in the constructor + fetch_orig : bool + Return the original data instead of copy if possible + """ # Set logger self.logger = get_module_logger("DataHandler") @@ -72,6 +89,7 @@ def __init__( self.instruments = instruments self.start_time = start_time self.end_time = end_time + self.fetch_orig = fetch_orig if init_data: with TimeInspector.logt("Init data"): self.init() @@ -138,7 +156,7 @@ def fetch( ------- pd.DataFrame: """ - df = fetch_df_by_index(self._data, selector, level) + df = fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig) df = self._fetch_df_by_col(df, col_set) if squeeze: # squeeze columns @@ -269,8 +287,10 @@ def __init__( for pname in "infer_processors", "learn_processors": for proc in locals()[pname]: getattr(self, pname).append( - init_instance_by_config(proc, processor_module, accept_types=(processor_module.Processor,)) - ) + init_instance_by_config( + proc, + None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module, + accept_types=processor_module.Processor)) self.process_type = process_type super().__init__(instruments, start_time, end_time, data_loader, **kwargs) @@ -354,15 +374,16 @@ def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False): # init raw data super().init(enable_cache=enable_cache) - if init_type == DataHandlerLP.IT_FIT_IND: - self.fit() - self.process_data() - elif init_type == DataHandlerLP.IT_LS: - self.process_data() - elif init_type == DataHandlerLP.IT_FIT_SEQ: - self.fit_process_data() - else: - raise NotImplementedError(f"This type of input is not supported") + with TimeInspector.logt("fit & process data"): + if init_type == DataHandlerLP.IT_FIT_IND: + self.fit() + self.process_data() + elif init_type == DataHandlerLP.IT_LS: + self.process_data() + elif init_type == DataHandlerLP.IT_FIT_SEQ: + self.fit_process_data() + else: + raise NotImplementedError(f"This type of input is not supported") # TODO: Be able to cache handler data. 
Save the memory for data processing @@ -396,7 +417,7 @@ def fetch( pd.DataFrame: """ df = self._get_df_by_key(data_key) - df = fetch_df_by_index(df, selector, level) + df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig) return self._fetch_df_by_col(df, col_set) def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list: diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index 8ee199bc0d..85a5e8389c 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -32,7 +32,7 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int: def fetch_df_by_index( - df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int] + df: pd.DataFrame, selector: Union[pd.Timestamp, slice, str, list], level: Union[str, int], fetch_orig=True, ) -> pd.DataFrame: """ fetch data from `data` with `selector` and `level` @@ -52,6 +52,11 @@ def fetch_df_by_index( idx_slc = (selector, slice(None, None)) if get_level_index(df, level) == 1: idx_slc = idx_slc[1], idx_slc[0] - return df.loc[ - pd.IndexSlice[idx_slc], - ] # This could be faster than df.loc(axis=0)[idx_slc] + if fetch_orig: + for slc in idx_slc: + if slc != slice(None, None): + return df.loc[pd.IndexSlice[idx_slc],] + else: + return df + else: + return df.loc[pd.IndexSlice[idx_slc],] diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py index b57879d0e9..0f721e0352 100644 --- a/qlib/workflow/utils.py +++ b/qlib/workflow/utils.py @@ -5,9 +5,9 @@ from . import R from .recorder import Recorder from ..log import get_module_logger - logger = get_module_logger("workflow", "INFO") + # function to handle the experiment when unusual program ending occurs def experiment_exit_handler(): """ @@ -31,9 +31,11 @@ def experiment_exception_hook(type, value, tb): value: Exception's value tb: Exception's traceback """ - logger.error("An exception has been raised.") + logger.error(f"An exception has been raised[{type.__name__}: {value}].") + + # Same as original format traceback.print_tb(tb) - print(f"{type}: {value}") + print(f"{type.__name__}: {value}") R.end_exp(recorder_status=Recorder.STATUS_FA) diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000000..98b01e0c35 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,61 @@ + +- [Download Qlib Data](#Download-Qlib-Data) + - [Download CN Data](#Download-CN-Data) + - [Downlaod US Data](#Downlaod-US-Data) + - [Download CN Simple Data](#Download-CN-Simple-Data) + - [Help](#Help) +- [Using in Qlib](#Using-in-Qlib) + - [US data](#US-data) + - [CN data](#CN-data) + + +## Download Qlib Data + + +### Download CN Data + +```bash +python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn +``` + +### Downlaod US Data + +> The US stock code contains 'PRN', and the directory cannot be created on Windows system: https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows + +```bash +python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us +``` + +### Download CN Simple Data + +```bash +python get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --region cn +``` + +### Help + +```bash +python get_data.py qlib_data --help +``` + +## Using in Qlib +> For more information: https://qlib.readthedocs.io/en/latest/start/initialization.html + + +### US data + +```python +import qlib +from qlib.config import REG_US +provider_uri = "~/.qlib/qlib_data/us_data" # target_dir 
+qlib.init(provider_uri=provider_uri, region=REG_US) +``` + +### CN data + +```python +import qlib +from qlib.config import REG_CN +provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir +qlib.init(provider_uri=provider_uri, region=REG_CN) +``` diff --git a/scripts/check_dump_bin.py b/scripts/check_dump_bin.py new file mode 100644 index 0000000000..7c2ceccdae --- /dev/null +++ b/scripts/check_dump_bin.py @@ -0,0 +1,144 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path +from concurrent.futures import ProcessPoolExecutor + +import qlib +from qlib.data import D + +import fire +import datacompy +import pandas as pd +from tqdm import tqdm +from loguru import logger + + +class CheckBin: + + NOT_IN_FEATURES = "not in features" + COMPARE_FALSE = "compare False" + COMPARE_TRUE = "compare True" + COMPARE_ERROR = "compare error" + + def __init__( + self, + qlib_dir: str, + csv_path: str, + check_fields: str = None, + freq: str = "day", + symbol_field_name: str = "symbol", + date_field_name: str = "date", + file_suffix: str = ".csv", + max_workers: int = 16, + ): + """ + + Parameters + ---------- + qlib_dir : str + qlib dir + csv_path : str + origin csv path + check_fields : str, optional + check fields, by default None, check qlib_dir/features//*..bin + freq : str, optional + freq, value from ["day", "1m"] + symbol_field_name: str, optional + symbol field name, by default "symbol" + date_field_name: str, optional + date field name, by default "date" + file_suffix: str, optional + csv file suffix, by default ".csv" + max_workers: int, optional + max workers, by default 16 + """ + self.qlib_dir = Path(qlib_dir).expanduser() + bin_path_list = list(self.qlib_dir.joinpath("features").iterdir()) + self.qlib_symbols = sorted(map(lambda x: x.name.lower(), bin_path_list)) + qlib.init( + provider_uri=str(self.qlib_dir.resolve()), + mount_path=str(self.qlib_dir.resolve()), + auto_mount=False, + redis_port=-1, + ) + csv_path = Path(csv_path).expanduser() + self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path]) + + if check_fields is None: + check_fields = list(map(lambda x: x.split(".")[0], bin_path_list[0].glob(f"*.bin"))) + else: + check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields + self.check_fields = list(map(lambda x: x.strip(), check_fields)) + self.qlib_fields = list(map(lambda x: f"${x}", self.check_fields)) + self.max_workers = max_workers + self.symbol_field_name = symbol_field_name + self.date_field_name = date_field_name + self.freq = freq + self.file_suffix = file_suffix + + def _compare(self, file_path: Path): + symbol = file_path.name.strip(self.file_suffix) + if symbol.lower() not in self.qlib_symbols: + return self.NOT_IN_FEATURES + # qlib data + qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq) + qlib_df.rename(columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True) + # csv data + origin_df = pd.read_csv(file_path) + origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name]) + if self.symbol_field_name not in origin_df.columns: + origin_df[self.symbol_field_name] = symbol + origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True) + origin_df.index.names = qlib_df.index.names + try: + compare = datacompy.Compare( + origin_df, + qlib_df, + on_index=True, + abs_tol=1e-08, # Optional, defaults to 0 + rel_tol=1e-05, # Optional, defaults to 0 + df1_name="Original", # Optional, defaults to 
'df1' + df2_name="New", # Optional, defaults to 'df2' + ) + _r = compare.matches(ignore_extra_columns=True) + return self.COMPARE_TRUE if _r else self.COMPARE_FALSE + except Exception as e: + logger.warning(f"{symbol} compare error: {e}") + return self.COMPARE_ERROR + + def check(self): + """Check whether the bin file after ``dump_bin.py`` is executed is consistent with the original csv file data + + """ + logger.info("start check......") + + error_list = [] + not_in_features = [] + compare_false = [] + with tqdm(total=len(self.csv_files)) as p_bar: + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)): + symbol = file_path.name.strip(self.file_suffix) + if _check_res == self.NOT_IN_FEATURES: + not_in_features.append(symbol) + elif _check_res == self.COMPARE_ERROR: + error_list.append(symbol) + elif _check_res == self.COMPARE_FALSE: + compare_false.append(symbol) + p_bar.update() + + logger.info("end of check......") + if error_list: + logger.warning(f"compare error: {error_list}") + if not_in_features: + logger.warning(f"not in features: {not_in_features}") + if compare_false: + logger.warning(f"compare False: {compare_false}") + logger.info( + f"total {len(self.csv_files)}, {len(error_list)} errors, {len(not_in_features)} not in features, {len(compare_false)} compare false" + ) + + +if __name__ == "__main__": + fire.Fire(CheckBin) diff --git a/scripts/data_collector/cn_index/README.md b/scripts/data_collector/cn_index/README.md new file mode 100644 index 0000000000..82f17eb5cb --- /dev/null +++ b/scripts/data_collector/cn_index/README.md @@ -0,0 +1,22 @@ +# CSI300/CSI100 History Companies Collection + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + +```bash +# parse instruments, using in qlib/instruments. 
+python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments + +# parse new companies +python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies + +# index_name support: CSI300, CSI100 +# help +python collector.py --help +``` + diff --git a/scripts/data_collector/csi/collector.py b/scripts/data_collector/cn_index/collector.py similarity index 58% rename from scripts/data_collector/csi/collector.py rename to scripts/data_collector/cn_index/collector.py index af10c12d68..5af9785ec3 100644 --- a/scripts/data_collector/csi/collector.py +++ b/scripts/data_collector/cn_index/collector.py @@ -4,8 +4,9 @@ import re import abc import sys -import bisect +import importlib from io import BytesIO +from typing import List from pathlib import Path import fire @@ -16,7 +17,9 @@ CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.utils import get_hs_calendar_list as get_calendar_list + +from data_collector.index import IndexBase +from data_collector.utils import get_calendar_list, get_trading_date_by_shift NEW_COMPANIES_URL = "http://www.csindex.com.cn/uploads/file/autofile/cons/{index_code}cons.xls" @@ -24,64 +27,48 @@ INDEX_CHANGES_URL = "http://www.csindex.com.cn/zh-CN/search/total?key=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC%E8%82%A1%E7%9A%84%E5%85%AC%E5%91%8A" -class CSIIndex: - - REMOVE = "remove" - ADD = "add" - - def __init__(self, qlib_dir=None): - """ - - Parameters - ---------- - qlib_dir: str - qlib data dir, default "Path(__file__).parent/qlib_data" - """ - - if qlib_dir is None: - qlib_dir = CUR_DIR.joinpath("qlib_data") - self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments") - self.instruments_dir.mkdir(exist_ok=True, parents=True) - self._calendar_list = None - - self.cache_dir = Path("~/.cache/csi").expanduser().resolve() - self.cache_dir.mkdir(exist_ok=True, parents=True) - +class CSIIndex(IndexBase): @property - def calendar_list(self) -> list: + def calendar_list(self) -> List[pd.Timestamp]: """get history trading date Returns ------- + calendar list """ return get_calendar_list(bench_code=self.index_name.upper()) @property - def new_companies_url(self): + def new_companies_url(self) -> str: return NEW_COMPANIES_URL.format(index_code=self.index_code) @property - def changes_url(self): + def changes_url(self) -> str: return INDEX_CHANGES_URL @property @abc.abstractmethod def bench_start_date(self) -> pd.Timestamp: - raise NotImplementedError() - - @property - @abc.abstractmethod - def index_code(self): - raise NotImplementedError() + """ + Returns + ------- + index start date + """ + raise NotImplementedError("rewrite bench_start_date") @property @abc.abstractmethod - def index_name(self): - raise NotImplementedError() + def index_code(self) -> str: + """ + Returns + ------- + index code + """ + raise NotImplementedError("rewrite index_code") @property @abc.abstractmethod - def html_table_index(self): + def html_table_index(self) -> int: """Which table of changes in html CSI300: 0 @@ -90,33 +77,19 @@ def html_table_index(self): """ raise NotImplementedError() - def _get_trading_date_by_shift(self, trading_date: pd.Timestamp, shift=1): - """get trading date by shift - - Parameters - ---------- - shift : int - shift, default is 1 - - trading_date : pd.Timestamp - trading date - Returns - ------- - - """ - 
left_index = bisect.bisect_left(self.calendar_list, trading_date) - try: - res = self.calendar_list[left_index + shift] - except IndexError: - res = trading_date - return res - - def _get_changes(self) -> pd.DataFrame: + def get_changes(self) -> pd.DataFrame: """get companies changes Returns ------- - + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] """ logger.info("get companies changes......") res = [] @@ -124,10 +97,21 @@ def _get_changes(self) -> pd.DataFrame: _df = self._read_change_from_url(_url) res.append(_df) logger.info("get companies changes finish") - return pd.concat(res) + return pd.concat(res, sort=False) @staticmethod - def normalize_symbol(symbol): + def normalize_symbol(symbol: str) -> str: + """ + + Parameters + ---------- + symbol: str + symbol + + Returns + ------- + symbol + """ symbol = f"{int(symbol):06}" return f"SH{symbol}" if symbol.startswith("60") else f"SZ{symbol}" @@ -141,7 +125,14 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: Returns ------- - + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] """ resp = requests.get(url) _text = resp.text @@ -151,8 +142,8 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: add_date = pd.Timestamp("-".join(date_list[0])) else: _date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0])) - add_date = self._get_trading_date_by_shift(_date, shift=0) - remove_date = self._get_trading_date_by_shift(add_date, shift=-1) + add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0) + remove_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=-1) logger.info(f"get {add_date} changes") try: excel_url = re.findall('.*href="(.*?xls.*?)".*', _text)[0] @@ -168,12 +159,12 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: _df = df_map[_s_name] _df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]] _df = _df.applymap(self.normalize_symbol) - _df.columns = ["symbol"] + _df.columns = [self.SYMBOL_FIELD_NAME] _df["type"] = _type - _df["date"] = _date + _df[self.DATE_FIELD_NAME] = _date tmp.append(_df) df = pd.concat(tmp) - except Exception: + except Exception as e: df = None _tmp_count = 0 for _df in pd.read_html(resp.content): @@ -188,9 +179,9 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: (_df.iloc[2:, 2], self.ADD, add_date), ]: _tmp_df = pd.DataFrame() - _tmp_df["symbol"] = _s.map(self.normalize_symbol) + _tmp_df[self.SYMBOL_FIELD_NAME] = _s.map(self.normalize_symbol) _tmp_df["type"] = _type - _tmp_df["date"] = _date + _tmp_df[self.DATE_FIELD_NAME] = _date tmp.append(_tmp_df) df = pd.concat(tmp) df.to_csv( @@ -203,20 +194,33 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: break return df - def _get_change_notices_url(self) -> list: + def _get_change_notices_url(self) -> List[str]: """get change notices url Returns ------- - + [url1, url2] """ resp = requests.get(self.changes_url) html = etree.HTML(resp.text) return html.xpath("//*[@id='itemContainer']//li/a/@href") - def _get_new_companies(self): + def get_new_companies(self) -> pd.DataFrame: + """ + + Returns + ------- + pd.DataFrame: + + symbol start_date end_date + SH600000 2000-01-01 2099-12-31 - logger.info("get new companies") + dtypes: + symbol: str + start_date: pd.Timestamp + end_date: pd.Timestamp + """ + 
logger.info("get new companies......") context = requests.get(self.new_companies_url).content with self.cache_dir.joinpath( f"{self.index_name.lower()}_new_companies.{self.new_companies_url.split('.')[-1]}" @@ -225,51 +229,19 @@ def _get_new_companies(self): _io = BytesIO(context) df = pd.read_excel(_io) df = df.iloc[:, [0, 4]] - df.columns = ["end_date", "symbol"] - df["symbol"] = df["symbol"].map(self.normalize_symbol) - df["end_date"] = pd.to_datetime(df["end_date"]) - df["start_date"] = self.bench_start_date + df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME] + df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol) + df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD]) + df[self.START_DATE_FIELD] = self.bench_start_date + logger.info("end of get new companies.") return df - def parse_instruments(self): - """parse csi300.txt - - Examples - ------- - $ python collector.py parse_instruments --qlib_dir ~/.qlib/qlib_data/cn_data - """ - logger.info(f"start parse {self.index_name.lower()} companies.....") - instruments_columns = ["symbol", "start_date", "end_date"] - changers_df = self._get_changes() - new_df = self._get_new_companies() - logger.info("parse history companies by changes......") - for _row in changers_df.sort_values("date", ascending=False).itertuples(index=False): - if _row.type == self.ADD: - min_end_date = new_df.loc[new_df["symbol"] == _row.symbol, "end_date"].min() - new_df.loc[ - (new_df["end_date"] == min_end_date) & (new_df["symbol"] == _row.symbol), "start_date" - ] = _row.date - else: - _tmp_df = pd.DataFrame( - [[_row.symbol, self.bench_start_date, _row.date]], columns=["symbol", "start_date", "end_date"] - ) - new_df = new_df.append(_tmp_df, sort=False) - - new_df.loc[:, instruments_columns].to_csv( - self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None - ) - logger.info(f"parse {self.index_name.lower()} companies finished.") - class CSI300(CSIIndex): @property def index_code(self): return "000300" - @property - def index_name(self): - return "csi300" - @property def bench_start_date(self) -> pd.Timestamp: return pd.Timestamp("2005-01-01") @@ -284,10 +256,6 @@ class CSI100(CSIIndex): def index_code(self): return "000903" - @property - def index_name(self): - return "csi100" - @property def bench_start_date(self) -> pd.Timestamp: return pd.Timestamp("2006-05-29") @@ -297,19 +265,39 @@ def html_table_index(self): return 1 -def parse_instruments(qlib_dir: str): +def get_instruments( + qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3 +): """ Parameters ---------- qlib_dir: str qlib data dir, default "Path(__file__).parent/qlib_data" + index_name: str + index name, value from ["csi100", "csi300"] + method: str + method, value from ["parse_instruments", "save_new_companies"] + request_retry: int + request retry, by default 5 + retry_sleep: int + request sleep, by default 3 + + Examples + ------- + # parse instruments + $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments + + # parse new companies + $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies + """ - qlib_dir = Path(qlib_dir).expanduser().resolve() - qlib_dir.mkdir(exist_ok=True, parents=True) - CSI300(qlib_dir).parse_instruments() - CSI100(qlib_dir).parse_instruments() + _cur_module = importlib.import_module("collector") + obj = getattr(_cur_module, 
f"{index_name.upper()}")( + qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep + ) + getattr(obj, method)() if __name__ == "__main__": - fire.Fire(parse_instruments) + fire.Fire(get_instruments) diff --git a/scripts/data_collector/csi/requirements.txt b/scripts/data_collector/cn_index/requirements.txt similarity index 100% rename from scripts/data_collector/csi/requirements.txt rename to scripts/data_collector/cn_index/requirements.txt diff --git a/scripts/data_collector/csi/README.md b/scripts/data_collector/csi/README.md deleted file mode 100644 index 52100df81c..0000000000 --- a/scripts/data_collector/csi/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# CSI300 History Companies Collection - -## Requirements - -```bash -pip install -r requirements.txt -``` - -## Collector Data - -```bash -python collector.py parse_instruments --qlib_dir ~/.qlib/stock_data/qlib_data -``` - diff --git a/scripts/data_collector/index.py b/scripts/data_collector/index.py new file mode 100644 index 0000000000..c5f3854fdc --- /dev/null +++ b/scripts/data_collector/index.py @@ -0,0 +1,202 @@ +import sys +import abc +from pathlib import Path +from typing import List + +import pandas as pd +from tqdm import tqdm +from loguru import logger + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent)) + + +from data_collector.utils import get_trading_date_by_shift + + +class IndexBase: + DEFAULT_END_DATE = pd.Timestamp("2099-12-31") + SYMBOL_FIELD_NAME = "symbol" + DATE_FIELD_NAME = "date" + START_DATE_FIELD = "start_date" + END_DATE_FIELD = "end_ate" + CHANGE_TYPE_FIELD = "type" + INSTRUMENTS_COLUMNS = [SYMBOL_FIELD_NAME, START_DATE_FIELD, END_DATE_FIELD] + REMOVE = "remove" + ADD = "add" + + def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3): + """ + + Parameters + ---------- + index_name: str + index name + qlib_dir: str + qlib directory, by default Path(__file__).resolve().parent.joinpath("qlib_data") + request_retry: int + request retry, by default 5 + retry_sleep: int + request sleep, by default 3 + """ + self.index_name = index_name + if qlib_dir is None: + qlib_dir = Path(__file__).resolve().parent.joinpath("qlib_data") + self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments") + self.instruments_dir.mkdir(exist_ok=True, parents=True) + self.cache_dir = Path(f"~/.cache/qlib/index/{self.index_name}").expanduser().resolve() + self.cache_dir.mkdir(exist_ok=True, parents=True) + self._request_retry = request_retry + self._retry_sleep = retry_sleep + + @property + @abc.abstractmethod + def bench_start_date(self) -> pd.Timestamp: + """ + Returns + ------- + index start date + """ + raise NotImplementedError("rewrite bench_start_date") + + @property + @abc.abstractmethod + def calendar_list(self) -> List[pd.Timestamp]: + """get history trading date + + Returns + ------- + calendar list + """ + raise NotImplementedError("rewrite calendar_list") + + @abc.abstractmethod + def get_new_companies(self) -> pd.DataFrame: + """ + + Returns + ------- + pd.DataFrame: + + symbol start_date end_date + SH600000 2000-01-01 2099-12-31 + + dtypes: + symbol: str + start_date: pd.Timestamp + end_date: pd.Timestamp + """ + raise NotImplementedError("rewrite get_new_companies") + + @abc.abstractmethod + def get_changes(self) -> pd.DataFrame: + """get companies changes + + Returns + ------- + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + 
symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] + """ + raise NotImplementedError("rewrite get_changes") + + def save_new_companies(self): + """save new companies + + Examples + ------- + $ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data + """ + df = self.get_new_companies() + df = df.drop_duplicates([self.SYMBOL_FIELD_NAME]) + df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv( + self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None + ) + + def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> pd.DataFrame: + """get changes with history companies + + Parameters + ---------- + history_companies : pd.DataFrame + symbol date + SH600000 2020-11-11 + + dtypes: + symbol: str + date: pd.Timestamp + + Return + -------- + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] + + """ + logger.info("parse changes from history companies......") + last_code = [] + result_df_list = [] + _columns = [self.DATE_FIELD_NAME, self.SYMBOL_FIELD_NAME, self.CHANGE_TYPE_FIELD] + for _trading_date in tqdm(sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True)): + _currenet_code = history_companies[history_companies[self.DATE_FIELD_NAME] == _trading_date][ + self.SYMBOL_FIELD_NAME + ].tolist() + if last_code: + add_code = list(set(last_code) - set(_currenet_code)) + remote_code = list(set(_currenet_code) - set(last_code)) + for _code in add_code: + result_df_list.append( + pd.DataFrame( + [[get_trading_date_by_shift(self.calendar_list, _trading_date, 1), _code, self.ADD]], + columns=_columns, + ) + ) + for _code in remote_code: + result_df_list.append( + pd.DataFrame( + [[get_trading_date_by_shift(self.calendar_list, _trading_date, 0), _code, self.REMOVE]], + columns=_columns, + ) + ) + last_code = _currenet_code + df = pd.concat(result_df_list) + logger.info("end of parse changes from history companies.") + return df + + def parse_instruments(self): + """parse instruments, eg: csi300.txt + + Examples + ------- + $ python collector.py parse_instruments --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data + """ + logger.info(f"start parse {self.index_name.lower()} companies.....") + instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD] + changers_df = self.get_changes() + new_df = self.get_new_companies().copy() + logger.info("parse history companies by changes......") + for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)): + if _row.type == self.ADD: + min_end_date = new_df.loc[new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD].min() + new_df.loc[ + (new_df[self.END_DATE_FIELD] == min_end_date) & (new_df[self.SYMBOL_FIELD_NAME] == _row.symbol), + self.START_DATE_FIELD, + ] = _row.date + else: + _tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns) + new_df = new_df.append(_tmp_df, sort=False) + + new_df.loc[:, instruments_columns].to_csv( + self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None + ) + logger.info(f"parse {self.index_name.lower()} companies finished.") diff --git a/scripts/data_collector/us_index/README.md b/scripts/data_collector/us_index/README.md new file mode 100644 index 0000000000..99a0a09c36 --- /dev/null +++ 
b/scripts/data_collector/us_index/README.md @@ -0,0 +1,22 @@ +# NASDAQ100/SP500/SP400/DJIA History Companies Collection + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + +```bash +# parse instruments, using in qlib/instruments. +python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments + +# parse new companies +python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies + +# index_name support: SP500, NASDAQ100, DJIA, SP400 +# help +python collector.py --help +``` + diff --git a/scripts/data_collector/us_index/collector.py b/scripts/data_collector/us_index/collector.py new file mode 100644 index 0000000000..ea1e974a0d --- /dev/null +++ b/scripts/data_collector/us_index/collector.py @@ -0,0 +1,278 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import abc +import sys +import importlib +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor +from typing import List + +import fire +import requests +import pandas as pd +from tqdm import tqdm +from loguru import logger + + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) + +from data_collector.index import IndexBase +from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift + + +WIKI_URL = "https://en.wikipedia.org/wiki" + +WIKI_INDEX_NAME_MAP = { + "NASDAQ100": "NASDAQ-100", + "SP500": "List_of_S%26P_500_companies", + "SP400": "List_of_S%26P_400_companies", + "DJIA": "Dow_Jones_Industrial_Average", +} + + +class WIKIIndex(IndexBase): + def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3): + super(WIKIIndex, self).__init__( + index_name=index_name, qlib_dir=qlib_dir, request_retry=request_retry, retry_sleep=retry_sleep + ) + + self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}" + + @property + @abc.abstractmethod + def bench_start_date(self) -> pd.Timestamp: + """ + Returns + ------- + index start date + """ + raise NotImplementedError("rewrite bench_start_date") + + @abc.abstractmethod + def get_changes(self) -> pd.DataFrame: + """get companies changes + + Returns + ------- + pd.DataFrame: + symbol date type + SH600000 2019-11-11 add + SH600000 2020-11-10 remove + dtypes: + symbol: str + date: pd.Timestamp + type: str, value from ["add", "remove"] + """ + raise NotImplementedError("rewrite get_changes") + + @property + def calendar_list(self) -> List[pd.Timestamp]: + """get history trading date + + Returns + ------- + calendar list + """ + _calendar_list = getattr(self, "_calendar_list", None) + if _calendar_list is None: + _calendar_list = list(filter(lambda x: x >= self.bench_start_date, get_calendar_list("US_ALL"))) + setattr(self, "_calendar_list", _calendar_list) + return _calendar_list + + def _request_new_companies(self) -> requests.Response: + resp = requests.get(self._target_url) + if resp.status_code != 200: + raise ValueError(f"request error: {self._target_url}") + + return resp + + def set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame: + _df = df.copy() + _df[self.SYMBOL_FIELD_NAME] = _df[self.SYMBOL_FIELD_NAME].str.strip() + _df[self.START_DATE_FIELD] = self.bench_start_date + _df[self.END_DATE_FIELD] = self.DEFAULT_END_DATE + return _df.loc[:, self.INSTRUMENTS_COLUMNS] + + def get_new_companies(self): + logger.info(f"get new companies {self.index_name} ......") + _data = 
deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)() + df_list = pd.read_html(_data.text) + for _df in df_list: + _df = self.filter_df(_df) + if (_df is not None) and (not _df.empty): + _df.columns = [self.SYMBOL_FIELD_NAME] + _df = self.set_default_date_range(_df) + logger.info(f"end of get new companies {self.index_name} ......") + return _df + + def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: + raise NotImplementedError("rewrite filter_df") + + +class NASDAQ100Index(WIKIIndex): + + HISTORY_COMPANIES_URL = ( + "https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD" + ) + MAX_WORKERS = 16 + + def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: + if not (set(df.columns) - {"Company", "Ticker"}): + return df.loc[:, ["Ticker"]].copy() + + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("2003-01-02") + + @deco_retry + def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame: + trade_date = trade_date.strftime("%Y-%m-%d") + cache_path = self.cache_dir.joinpath(f"{trade_date}_history_companies.pkl") + if cache_path.exists() and use_cache: + df = pd.read_pickle(cache_path) + else: + url = self.HISTORY_COMPANIES_URL.format(trade_date=trade_date) + resp = requests.post(url) + if resp.status_code != 200: + raise ValueError(f"request error: {url}") + df = pd.DataFrame(resp.json()["aaData"]) + df[self.DATE_FIELD_NAME] = trade_date + df.rename(columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True) + if not df.empty: + df.to_pickle(cache_path) + return df + + def get_history_companies(self): + logger.info(f"start get history companies......") + all_history = [] + error_list = [] + with tqdm(total=len(self.calendar_list)) as p_bar: + with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor: + for _trading_date, _df in zip( + self.calendar_list, executor.map(self._request_history_companies, self.calendar_list) + ): + if _df.empty: + error_list.append(_trading_date) + else: + all_history.append(_df) + p_bar.update() + + if error_list: + logger.warning(f"get error: {error_list}") + logger.info(f"total {len(self.calendar_list)}, error {len(error_list)}") + logger.info(f"end of get history companies.") + return pd.concat(all_history, sort=False) + + def get_changes(self): + return self.get_changes_with_history_companies(self.get_history_companies()) + + +class DJIAIndex(WIKIIndex): + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("2000-01-01") + + def get_changes(self) -> pd.DataFrame: + pass + + def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: + if "Symbol" in df.columns: + _df = df.loc[:, ["Symbol"]].copy() + _df["Symbol"] = _df["Symbol"].apply(lambda x: x.split(":")[-1]) + return _df + + def parse_instruments(self): + logger.warning(f"No suitable data source has been found!") + + +class SP500Index(WIKIIndex): + WIKISP500_CHANGES_URL = "https://en.wikipedia.org/wiki/List_of_S%26P_500_companies" + + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("1999-01-01") + + def get_changes(self) -> pd.DataFrame: + logger.info(f"get sp500 history changes......") + # NOTE: may update the index of the table + changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1] + changes_df = changes_df.iloc[:, [0, 1, 3]] + changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE] + changes_df[self.DATE_FIELD_NAME] = 
pd.to_datetime(changes_df[self.DATE_FIELD_NAME]) + _result = [] + for _type in [self.ADD, self.REMOVE]: + _df = changes_df.copy() + _df[self.CHANGE_TYPE_FIELD] = _type + _df[self.SYMBOL_FIELD_NAME] = _df[_type] + _df.dropna(subset=[self.SYMBOL_FIELD_NAME], inplace=True) + if _type == self.ADD: + _df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply( + lambda x: get_trading_date_by_shift(self.calendar_list, x, 0) + ) + else: + _df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply( + lambda x: get_trading_date_by_shift(self.calendar_list, x, -1) + ) + _result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]]) + logger.info(f"end of get sp500 history changes.") + return pd.concat(_result, sort=False) + + def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: + if "Symbol" in df.columns: + return df.loc[:, ["Symbol"]].copy() + + +class SP400Index(WIKIIndex): + @property + def bench_start_date(self) -> pd.Timestamp: + return pd.Timestamp("2000-01-01") + + def get_changes(self) -> pd.DataFrame: + pass + + def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: + if "Ticker symbol" in df.columns: + return df.loc[:, ["Ticker symbol"]].copy() + + def parse_instruments(self): + logger.warning(f"No suitable data source has been found!") + + +def get_instruments( + qlib_dir: str, index_name: str, method: str = "parse_instruments", request_retry: int = 5, retry_sleep: int = 3 +): + """ + + Parameters + ---------- + qlib_dir: str + qlib data dir, default "Path(__file__).parent/qlib_data" + index_name: str + index name, value from ["SP500", "NASDAQ100", "DJIA", "SP400"] + method: str + method, value from ["parse_instruments", "save_new_companies"] + request_retry: int + request retry, by default 5 + retry_sleep: int + request sleep, by default 3 + + Examples + ------- + # parse instruments + $ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments + + # parse new companies + $ python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies + + """ + _cur_module = importlib.import_module("collector") + obj = getattr(_cur_module, f"{index_name.upper()}Index")( + qlib_dir=qlib_dir, index_name=index_name, request_retry=request_retry, retry_sleep=retry_sleep + ) + getattr(obj, method)() + + +if __name__ == "__main__": + fire.Fire(get_instruments) diff --git a/scripts/data_collector/us_index/requirements.txt b/scripts/data_collector/us_index/requirements.txt new file mode 100644 index 0000000000..7292710384 --- /dev/null +++ b/scripts/data_collector/us_index/requirements.txt @@ -0,0 +1,6 @@ +logure +fire +requests +pandas +lxml +loguru diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 08fef7ec9f..8555696428 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -3,56 +3,69 @@ import re import time +import bisect import pickle import requests +import functools from pathlib import Path import pandas as pd from lxml import etree +from loguru import logger +from yahooquery import Ticker -SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}" -CSI300_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.000300&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" -SH600000_BENCH_URL = 
"http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.600000&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" +HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}" -CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" +CALENDAR_URL_BASE = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid={market}.{bench_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20991231" CALENDAR_BENCH_URL_MAP = { - "CSI300": CALENDAR_URL_BASE.format(bench_code="000300"), - "CSI100": CALENDAR_URL_BASE.format(bench_code="000903"), + "CSI300": CALENDAR_URL_BASE.format(market=1, bench_code="000300"), + "CSI100": CALENDAR_URL_BASE.format(market=1, bench_code="000903"), # NOTE: Use the time series of SH600000 as the sequence of all stocks - "ALL": CALENDAR_URL_BASE.format(bench_code="600000"), + "ALL": CALENDAR_URL_BASE.format(market=1, bench_code="000905"), + # NOTE: Use the time series of ^GSPC(SP500) as the sequence of all stocks + "US_ALL": "^GSPC", } + _BENCH_CALENDAR_LIST = None _ALL_CALENDAR_LIST = None _HS_SYMBOLS = None +_US_SYMBOLS = None _CALENDAR_MAP = {} # NOTE: Until 2020-10-20 20:00:00 MINIMUM_SYMBOLS_NUM = 3900 -def get_hs_calendar_list(bench_code="CSI300") -> list: +def get_calendar_list(bench_code="CSI300") -> list: """get SH/SZ history calendar list Parameters ---------- bench_code: str - value from ["CSI300", "CSI500", "ALL"] + value from ["CSI300", "CSI500", "ALL", "US_ALL"] Returns ------- history calendar list """ + logger.info(f"get calendar list: {bench_code}......") + def _get_calendar(url): _value_list = requests.get(url).json()["data"]["klines"] return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list)) calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: - calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code]) + if bench_code.startswith("US_"): + df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") + calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() + else: + calendar = _get_calendar(CALENDAR_BENCH_URL_MAP[bench_code]) _CALENDAR_MAP[bench_code] = calendar + logger.info(f"end of get calendar list: {bench_code}.") return calendar @@ -68,13 +81,14 @@ def get_hs_stock_symbols() -> list: def _get_symbol(): _res = set() for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")): - resp = requests.get(SYMBOLS_URL.format(s_type=_k)) + resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k)) _res |= set( map( lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), ) ) + time.sleep(3) return _res if _HS_SYMBOLS is None: @@ -99,6 +113,84 @@ def _get_symbol(): return _HS_SYMBOLS +def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: + """get US stock symbols + + Returns + ------- + stock symbols + """ + global _US_SYMBOLS + + @deco_retry + def _get_eastmoney(): + url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12" + resp = requests.get(url) + if resp.status_code != 200: + raise ValueError("request error") + try: + _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] + except 
Exception as e: + logger.warning(f"request error: {e}") + raise + if len(_symbols) < 8000: + raise ValueError("request error") + return _symbols + + @deco_retry + def _get_nasdaq(): + _res_symbols = [] + for _name in ["otherlisted", "nasdaqtraded"]: + url = f"ftp://ftp.nasdaqtrader.com/SymbolDirectory/{_name}.txt" + df = pd.read_csv(url, sep="|") + df = df.rename(columns={"ACT Symbol": "Symbol"}) + _symbols = df["Symbol"].dropna() + _symbols = _symbols.str.replace("$", "-P", regex=False) + _symbols = _symbols.str.replace(".W", "-WT", regex=False) + _symbols = _symbols.str.replace(".U", "-UN", regex=False) + _symbols = _symbols.str.replace(".R", "-RI", regex=False) + _symbols = _symbols.str.replace(".", "-", regex=False) + _res_symbols += _symbols.unique().tolist() + return _res_symbols + + @deco_retry + def _get_nyse(): + url = "https://www.nyse.com/api/quotes/filter" + _parms = { + "instrumentType": "EQUITY", + "pageNumber": 1, + "sortColumn": "NORMALIZED_TICKER", + "sortOrder": "ASC", + "maxResultsPerPage": 10000, + "filterToken": "", + } + resp = requests.post(url, json=_parms) + if resp.status_code != 200: + raise ValueError("request error") + try: + _symbols = [_v["symbolTicker"].replace("-", "-P") for _v in resp.json()] + except Exception as e: + logger.warning(f"request error: {e}") + _symbols = [] + return _symbols + + if _US_SYMBOLS is None: + _all_symbols = _get_eastmoney() + _get_nasdaq() + _get_nyse() + if qlib_data_path is not None: + for _index in ["nasdaq100", "sp500"]: + ins_df = pd.read_csv( + Path(qlib_data_path).joinpath(f"instruments/{_index}.txt"), + sep="\t", + names=["symbol", "start_date", "end_date"], + ) + _all_symbols += ins_df["symbol"].unique().tolist() + _US_SYMBOLS = sorted( + set(map(lambda x: x.replace(".", "-"), filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols))) + ) + + return _US_SYMBOLS + + def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str: """symbol suffix to prefix @@ -137,5 +229,52 @@ def symbol_prefix_to_sufix(symbol: str, capital: bool = True) -> str: return res.upper() if capital else res.lower() +def deco_retry(retry: int = 5, retry_sleep: int = 3): + def deco_func(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + _retry = 5 if callable(retry) else retry + _result = None + for _i in range(1, _retry + 1): + try: + _result = func(*args, **kwargs) + break + except Exception as e: + logger.warning(f"{func.__name__}: {_i} :{e}") + if _i == _retry: + raise + time.sleep(retry_sleep) + return _result + + return wrapper + + return deco_func(retry) if callable(retry) else deco_func + + +def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): + """get trading date by shift + + Parameters + ---------- + trading_list: list + trading calendar list + shift : int + shift, default is 1 + + trading_date : pd.Timestamp + trading date + Returns + ------- + + """ + trading_date = pd.Timestamp(trading_date) + left_index = bisect.bisect_left(trading_list, trading_date) + try: + res = trading_list[left_index + shift] + except IndexError: + res = trading_date + return res + + if __name__ == "__main__": assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index 4f1f4c650d..1e65aeaed2 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -18,31 +18,29 @@ pip install -r requirements.txt ## Collector Data -### Download data -> Normalize data 
-> Dump data +### Download data and Normalize data ```bash -python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data +python collector.py collector_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d ``` -### Download Data From Yahoo Finance +### Download Data ```bash -python collector.py download_data --source_dir ~/.qlib/stock_data/source +python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d ``` -### Normalize Yahoo Finance Data +### Normalize Data ```bash -python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize +python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN ``` -### Manual Ajust Yahoo Finance Data - +### Help ```bash -python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize +pythono collector.py collector_data --help ``` -### Dump Yahoo Finance Data +## Parameters -```bash -python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data -``` +- interval: 1m or 1d +- region: CN or US diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 9456c6bc39..69c7f8f15a 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -1,8 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import abc import sys +import copy import time +import datetime +import importlib from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed @@ -13,33 +17,103 @@ from tqdm import tqdm from loguru import logger from yahooquery import Ticker +from dateutil.tz import tzlocal CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from dump_bin import DumpData -from data_collector.utils import get_hs_calendar_list as get_calendar_list, get_hs_stock_symbols +from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols -INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg=19900101&end=20220101" -MIN_NUMBERS_TRADING = 252 / 4 +INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}" +REGION_CN = "CN" +REGION_US = "US" class YahooCollector: - def __init__(self, save_dir: [str, Path], max_workers=4, asynchronous=False, max_collector_count=5, delay=0): + START_DATETIME = pd.Timestamp("2000-01-01") + HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 5)) + END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)) + + def __init__( + self, + save_dir: [str, Path], + start=None, + end=None, + interval="1d", + max_workers=4, + max_collector_count=5, + delay=0, + check_data_length: bool = False, + limit_nums: int = None, + ): + """ + Parameters + ---------- + save_dir: str + stock save dir + max_workers: int + workers, default 4 + max_collector_count: int + default 5 + delay: float 
+ time.sleep(delay), default 0 + interval: str + freq, value from [1m, 1d], default 1m + start: str + start datetime, default None + end: str + end datetime, default None + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None + """ self.save_dir = Path(save_dir).expanduser().resolve() self.save_dir.mkdir(parents=True, exist_ok=True) self._delay = delay - self._stock_list = None + self.stock_list = sorted(set(self.get_stock_list())) + if limit_nums is not None: + try: + self.stock_list = self.stock_list[: int(limit_nums)] + except Exception as e: + logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") self.max_workers = max_workers - self._asynchronous = asynchronous self._max_collector_count = max_collector_count self._mini_symbol_map = {} + self._interval = interval + self._check_small_data = check_data_length + self._start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME + self._end_datetime = pd.Timestamp(str(end)) if end else self.END_DATETIME + if self._interval == "1m": + self._start_datetime = max(self._start_datetime, self.HIGH_FREQ_START_DATETIME) + elif self._interval == "1d": + self._start_datetime = max(self._start_datetime, self.START_DATETIME) + else: + raise ValueError(f"interval error: {self._interval}") + + self._start_datetime = self.convert_datetime(self._start_datetime) + self._end_datetime = self.convert_datetime(min(self._end_datetime, self.END_DATETIME)) + + @property + @abc.abstractmethod + def min_numbers_trading(self): + # daily, one year: 252 / 4 + # us 1min, a week: 6.5 * 60 * 5 + # cn 1min, a week: 4 * 60 * 5 + raise NotImplementedError("rewirte min_numbers_trading") + + @abc.abstractmethod + def get_stock_list(self): + raise NotImplementedError("rewirte get_stock_list") @property - def stock_list(self): - if self._stock_list is None: - self._stock_list = get_hs_stock_symbols() - return self._stock_list + @abc.abstractclassmethod + def _timezone(self): + raise NotImplementedError("rewrite get_timezone") + + def convert_datetime(self, dt: pd.Timestamp): + dt = pd.Timestamp(dt, tz=self._timezone).timestamp() + return pd.Timestamp(dt, tz=tzlocal(), unit="s") def _sleep(self): time.sleep(self._delay) @@ -57,63 +131,95 @@ def save_stock(self, symbol, df: pd.DataFrame): if df.empty: raise ValueError("df is empty") - symbol_s = symbol.split(".") - symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}" + symbol = self.normalize_symbol(symbol) stock_path = self.save_dir.joinpath(f"{symbol}.csv") df["symbol"] = symbol - df.to_csv(stock_path, index=False) + if stock_path.exists(): + with stock_path.open("a") as fp: + df.to_csv(fp, index=False, header=None) + else: + with stock_path.open("w") as fp: + df.to_csv(fp, index=False) - def _temp_save_small_data(self, symbol, df): - if len(df) <= MIN_NUMBERS_TRADING: - logger.warning(f"the number of trading days of {symbol} is less than {MIN_NUMBERS_TRADING}!") + def _save_small_data(self, symbol, df): + if len(df) <= self.min_numbers_trading: + logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!") _temp = self._mini_symbol_map.setdefault(symbol, []) _temp.append(df.copy()) + return None else: if symbol in self._mini_symbol_map: self._mini_symbol_map.pop(symbol) + return symbol + + def _get_from_remote(self, symbol): + def _get_simple(start_, end_): + self._sleep() + try: + _resp = Ticker(symbol, asynchronous=False).history(interval=self._interval, 
start=start_, end=end_) + if isinstance(_resp, pd.DataFrame): + return _resp.reset_index() + else: + logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{_resp}") + except Exception as e: + logger.warning(f"{symbol}-{self._interval}-{start_}-{end_}:{e}") + + _result = None + if self._interval == "1d": + _result = _get_simple(self._start_datetime, self._end_datetime) + elif self._interval == "1m": + _start_date = self._start_datetime.date() + pd.Timedelta(days=1) + _end_date = self._end_datetime.date() + if _start_date >= _end_date: + _result = _get_simple(self._start_datetime, self._end_datetime) + else: + _res = [] + + def _get_multi(start_, end_): + _resp = _get_simple(start_, end_) + if _resp is not None: + _res.append(_resp) + + for _s, _e in ((self._start_datetime, _start_date), (_end_date, self._end_datetime)): + _get_multi(_s, _e) + for _start in pd.date_range(_start_date, _end_date, closed="left"): + _end = _start + pd.Timedelta(days=1) + self._sleep() + _get_multi(_start, _end) + if _res: + _result = pd.concat(_res, sort=False).sort_values(["symbol", "date"]) + else: + raise ValueError(f"cannot support {self._interval}") + return _result + + def _get_data(self, symbol): + _result = None + df = self._get_from_remote(symbol) + if isinstance(df, pd.DataFrame): + if not df.empty: + if self._check_small_data: + if self._save_small_data(symbol, df) is not None: + _result = symbol + self.save_stock(symbol, df) + else: + _result = symbol + self.save_stock(symbol, df) + return _result def _collector(self, stock_list): error_symbol = [] - with ThreadPoolExecutor(max_workers=self.max_workers) as worker: - futures = {} - p_bar = tqdm(total=len(stock_list)) - for symbols in [stock_list[i : i + self.max_workers] for i in range(0, len(stock_list), self.max_workers)]: - self._sleep() - resp = Ticker(symbols, asynchronous=self._asynchronous, max_workers=self.max_workers).history( - period="max" - ) - if isinstance(resp, dict): - for symbol, df in resp.items(): - if isinstance(df, pd.DataFrame): - self._temp_save_small_data(self, df) - futures[ - worker.submit( - self.save_stock, symbol, df.reset_index().rename(columns={"index": "date"}) - ) - ] = symbol - else: - error_symbol.append(symbol) - else: - for symbol, df in resp.reset_index().groupby("symbol"): - self._temp_save_small_data(self, df) - futures[worker.submit(self.save_stock, symbol, df)] = symbol - p_bar.update(self.max_workers) - p_bar.close() - - with tqdm(total=len(futures.values())) as p_bar: - for future in as_completed(futures): - try: - future.result() - except Exception as e: - logger.error(e) - error_symbol.append(futures[future]) + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + with tqdm(total=len(stock_list)) as p_bar: + for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)): + if _result is None: + error_symbol.append(_symbol) p_bar.update() print(error_symbol) logger.info(f"error symbol nums: {len(error_symbol)}") logger.info(f"current get symbol nums: {len(stock_list)}") error_symbol.extend(self._mini_symbol_map.keys()) - return error_symbol + return sorted(set(error_symbol)) def collector_data(self): """collector data""" @@ -126,81 +232,140 @@ def collector_data(self): stock_list = self._collector(stock_list) logger.info(f"{i+1} finish.") for _symbol, _df_list in self._mini_symbol_map.items(): - self.save_stock(_symbol, max(_df_list, key=len)) - - logger.warning(f"less than {MIN_NUMBERS_TRADING} stock list: {list(self._mini_symbol_map.keys())}") + 
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
+        if self._mini_symbol_map:
+            logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
+        logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")

         self.download_index_data()

+    @abc.abstractmethod
+    def download_index_data(self):
+        """download index data"""
+        raise NotImplementedError("rewrite download_index_data")
+
+    @abc.abstractmethod
+    def normalize_symbol(self, symbol: str):
+        """normalize symbol"""
+        raise NotImplementedError("rewrite normalize_symbol")
+
+
+class YahooCollectorCN(YahooCollector):
+    @property
+    def min_numbers_trading(self):
+        if self._interval == "1m":
+            return 60 * 4 * 5
+        elif self._interval == "1d":
+            return 252 / 4
+
+    def get_stock_list(self):
+        logger.info("get HS stock symbols......")
+        symbols = get_hs_stock_symbols()
+        logger.info(f"get {len(symbols)} symbols.")
+        return symbols
+
     def download_index_data(self):
         # TODO: from MSN
-        for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
-            logger.info(f"get bench data: {_index_name}({_index_code})......")
-            df = pd.DataFrame(
-                map(
-                    lambda x: x.split(","),
-                    requests.get(INDEX_BENCH_URL.format(index_code=_index_code)).json()["data"]["klines"],
-                )
-            )
-            df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
-            df["date"] = pd.to_datetime(df["date"])
-            df = df.astype(float, errors="ignore")
-            df["adjclose"] = df["close"]
-            df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
+        # FIXME: 1m
+        if self._interval == "1d":
+            _format = "%Y%m%d"
+            _begin = self._start_datetime.strftime(_format)
+            _end = (self._end_datetime + pd.Timedelta(days=-1)).strftime(_format)
+            for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
+                logger.info(f"get bench data: {_index_name}({_index_code})......")
+                try:
+                    df = pd.DataFrame(
+                        map(
+                            lambda x: x.split(","),
+                            requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[
+                                "data"
+                            ]["klines"],
+                        )
+                    )
+                except Exception as e:
+                    logger.warning(f"get {_index_name} error: {e}")
+                    continue
+                df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"]
+                df["date"] = pd.to_datetime(df["date"])
+                df = df.astype(float, errors="ignore")
+                df["adjclose"] = df["close"]
+                df.to_csv(self.save_dir.joinpath(f"sh{_index_code}.csv"), index=False)
+        else:
+            logger.warning(f"{self.__class__.__name__} {self._interval} does not support: download_index_data")

+    def normalize_symbol(self, symbol):
+        symbol_s = symbol.split(".")
+        symbol = f"sh{symbol_s[0]}" if symbol_s[-1] == "ss" else f"sz{symbol_s[0]}"
+        return symbol

-class Run:
-    def __init__(self, source_dir=None, normalize_dir=None, qlib_dir=None, max_workers=4):
-        """
+    @property
+    def _timezone(self):
+        return "Asia/Shanghai"

-        Parameters
-        ----------
-        source_dir: str
-            The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
-        normalize_dir: str
-            Directory for normalize data, default "Path(__file__).parent/normalize"
-        qlib_dir: str
-            qlib data dir; usage of provider_uri, default "Path(__file__).parent/qlib_data"
-        max_workers: int
-            Concurrent number, default is 4
-        """
-        if source_dir is None:
-            source_dir = CUR_DIR.joinpath("source")
-        self.source_dir = Path(source_dir).expanduser().resolve()
-        self.source_dir.mkdir(parents=True, exist_ok=True)
-        if normalize_dir is
None: - normalize_dir = CUR_DIR.joinpath("normalize") - self.normalize_dir = Path(normalize_dir).expanduser().resolve() - self.normalize_dir.mkdir(parents=True, exist_ok=True) +class YahooCollectorUS(YahooCollector): + @property + def min_numbers_trading(self): + if self._interval == "1m": + return 60 * 6.5 * 5 + elif self._interval == "1d": + return 252 / 4 + + def get_stock_list(self): + logger.info("get US stock symbols......") + symbols = get_us_stock_symbols() + [ + "^GSPC", + "^NDX", + "^DJI", + ] + logger.info(f"get {len(symbols)} symbols.") + return symbols - if qlib_dir is None: - qlib_dir = CUR_DIR.joinpath("qlib_data") - self.qlib_dir = Path(qlib_dir).expanduser().resolve() - self.qlib_dir.mkdir(parents=True, exist_ok=True) + def download_index_data(self): + pass - self.max_workers = max_workers + def normalize_symbol(self, symbol): + return symbol.upper() - def normalize_data(self): - """normalize data + @property + def _timezone(self): + return "America/New_York" - Examples - --------- - $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize +class YahooNormalize: + COLUMNS = ["open", "close", "high", "low", "volume"] + + def __init__(self, source_dir: [str, Path], target_dir: [str, Path], max_workers: int = 16): """ - def _normalize(file_path: Path): - columns = ["open", "close", "high", "low", "volume"] - df = pd.read_csv(file_path) + Parameters + ---------- + source_dir: str or Path + The directory where the raw data collected from the Internet is saved + target_dir: str or Path + Directory for normalize data + max_workers: int + Concurrent number, default is 16 + """ + if not (source_dir and target_dir): + raise ValueError("source_dir and target_dir cannot be None") + self._source_dir = Path(source_dir).expanduser() + self._target_dir = Path(target_dir).expanduser() + self._max_workers = max_workers + self._calendar_list = self._get_calendar_list() + + def normalize_data(self): + logger.info("normalize data......") + + def _normalize(source_path: Path): + columns = copy.deepcopy(self.COLUMNS) + df = pd.read_csv(source_path) df.set_index("date", inplace=True) df.index = pd.to_datetime(df.index) df = df[~df.index.duplicated(keep="first")] - - # using China stock market data calendar - df = df.reindex(pd.Index(get_calendar_list("ALL"))) + if self._calendar_list is not None: + df = df.reindex(pd.DataFrame(index=self._calendar_list).loc[df.index.min() : df.index.max()].index) df.sort_index(inplace=True) - df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), set(df.columns) - {"symbol"}] = np.nan df["factor"] = df["adjclose"] / df["close"] for _col in columns: @@ -213,22 +378,17 @@ def _normalize(file_path: Path): columns += ["change", "factor"] df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan df.index.names = ["date"] - df.loc[:, columns].to_csv(self.normalize_dir.joinpath(file_path.name)) + df.loc[:, columns].to_csv(self._target_dir.joinpath(source_path.name)) - with ThreadPoolExecutor(max_workers=self.max_workers) as worker: - file_list = list(self.source_dir.glob("*.csv")) + with ThreadPoolExecutor(max_workers=self._max_workers) as worker: + file_list = list(self._source_dir.glob("*.csv")) with tqdm(total=len(file_list)) as p_bar: for _ in worker.map(_normalize, file_list): p_bar.update() def manual_adj_data(self): - """manual adjust data - - Examples - -------- - $ python collector.py manual_adj_data --normalize_dir ~/.qlib/stock_data/normalize - - """ + """adjust data""" + 
logger.info("manual adjust data......") def _adj(file_path: Path): df = pd.read_csv(file_path) @@ -244,59 +404,166 @@ def _adj(file_path: Path): df[_col] = df[_col] / _close else: pass - df.reset_index().to_csv(self.normalize_dir.joinpath(file_path.name), index=False) + df.reset_index().to_csv(self._target_dir.joinpath(file_path.name), index=False) - with ThreadPoolExecutor(max_workers=self.max_workers) as worker: - file_list = list(self.normalize_dir.glob("*.csv")) + with ThreadPoolExecutor(max_workers=self._max_workers) as worker: + file_list = list(self._target_dir.glob("*.csv")) with tqdm(total=len(file_list)) as p_bar: for _ in worker.map(_adj, file_list): p_bar.update() - def dump_data(self): - """dump yahoo data + def normalize(self): + self.normalize_data() + self.manual_adj_data() + + @abc.abstractmethod + def _get_calendar_list(self): + """Get benchmark calendar""" + raise NotImplementedError("") - Examples - --------- - $ python collector.py dump_data --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data +class YahooNormalizeUS(YahooNormalize): + def _get_calendar_list(self): + # TODO: from MSN + return get_calendar_list("US_ALL") + + +class YahooNormalizeCN(YahooNormalize): + def _get_calendar_list(self): + # TODO: from MSN + return get_calendar_list("ALL") + + +class Run: + def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN): """ - DumpData(csv_path=self.normalize_dir, qlib_dir=self.qlib_dir, works=self.max_workers).dump( - include_fields="close,open,high,low,volume,change,factor" - ) - def download_data(self, asynchronous=False, max_collector_count=5, delay=0): + Parameters + ---------- + source_dir: str + The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, default "Path(__file__).parent/normalize" + max_workers: int + Concurrent number, default is 4 + region: str + region, value from ["CN", "US"], default "CN" + """ + if source_dir is None: + source_dir = CUR_DIR.joinpath("source") + self.source_dir = Path(source_dir).expanduser().resolve() + self.source_dir.mkdir(parents=True, exist_ok=True) + + if normalize_dir is None: + normalize_dir = CUR_DIR.joinpath("normalize") + self.normalize_dir = Path(normalize_dir).expanduser().resolve() + self.normalize_dir.mkdir(parents=True, exist_ok=True) + + self._cur_module = importlib.import_module("collector") + self.max_workers = max_workers + self.region = region + + def download_data( + self, + max_collector_count=5, + delay=0, + start=None, + end=None, + interval="1d", + check_data_length=False, + limit_nums=None, + ): """download data from Internet + Parameters + ---------- + max_collector_count: int + default 5 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1m, 1d], default 1m + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None Examples --------- - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source - + # get daily data + $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + # get 1m data + $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 
2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ - YahooCollector( + + _class = getattr(self._cur_module, f"YahooCollector{self.region.upper()}") + _class( self.source_dir, max_workers=self.max_workers, - asynchronous=asynchronous, max_collector_count=max_collector_count, delay=delay, + start=start, + end=end, + interval=interval, + check_data_length=check_data_length, + limit_nums=limit_nums, ).collector_data() - def download_index_data(self): - YahooCollector(self.source_dir).download_index_data() - - def download_bench_data(self): - """download bench stock data(SH000300)""" + def normalize_data(self): + """normalize data - def collector_data(self): - """download -> normalize -> dump data + Examples + --------- + $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN + """ + _class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}") + _class(self.source_dir, self.normalize_dir, self.max_workers).normalize() + + def collector_data( + self, + max_collector_count=5, + delay=0, + start=None, + end=None, + interval="1d", + check_data_length=False, + limit_nums=None, + ): + """download -> normalize + Parameters + ---------- + max_collector_count: int + default 5 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1m, 1d], default 1m + start: str + start datetime, default "2000-01-01" + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))`` + check_data_length: bool + check data length, by default False + limit_nums: int + using for debug, by default None Examples ------- - $ python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize_dir --qlib_dir ~/.qlib/stock_data/qlib_data + python collector.py collector_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d """ - self.download_data() + self.download_data( + max_collector_count=max_collector_count, + delay=delay, + start=start, + end=end, + interval=interval, + check_data_length=check_data_length, + limit_nums=limit_nums, + ) self.normalize_data() - self.manual_adj_data() - self.dump_data() if __name__ == "__main__": diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index d972f6318e..2e44c454ef 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -1,10 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+import abc import shutil +import traceback from pathlib import Path +from typing import Iterable, List, Union from functools import partial -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed, ProcessPoolExecutor import fire import numpy as np @@ -13,8 +16,20 @@ from loguru import logger -class DumpData(object): - FILE_SUFFIX = ".csv" +class DumpDataBase: + INSTRUMENTS_START_FIELD = "start_datetime" + INSTRUMENTS_END_FIELD = "end_datetime" + CALENDARS_DIR_NAME = "calendars" + FEATURES_DIR_NAME = "features" + INSTRUMENTS_DIR_NAME = "instruments" + DUMP_FILE_SUFFIX = ".bin" + DAILY_FORMAT = "%Y-%m-%d" + HIGH_FREQ_FORMAT = "%Y-%m-%d %H:%M:%S" + INSTRUMENTS_SEP = "\t" + INSTRUMENTS_FILE_NAME = "all.txt" + + UPDATE_MODE = "update" + ALL_MODE = "all" def __init__( self, @@ -22,8 +37,13 @@ def __init__( qlib_dir: str, backup_dir: str = None, freq: str = "day", - works: int = None, + max_workers: int = 16, date_field_name: str = "date", + file_suffix: str = ".csv", + symbol_field_name: str = "symbol", + exclude_fields: str = "", + include_fields: str = "", + limit_nums: int = None, ): """ @@ -37,80 +57,101 @@ def __init__( if backup_dir is not None, backup qlib_dir to backup_dir freq: str, default "day" transaction frequency - works: int, default None + max_workers: int, default None number of threads date_field_name: str, default "date" the name of the date field in the csv + file_suffix: str, default ".csv" + file suffix + symbol_field_name: str, default "symbol" + symbol field name + include_fields: tuple + dump fields + exclude_fields: tuple + fields not dumped + limit_nums: int + Use when debugging, default None """ csv_path = Path(csv_path).expanduser() - self.csv_files = sorted(csv_path.glob(f"*{self.FILE_SUFFIX}") if csv_path.is_dir() else [csv_path]) + if isinstance(exclude_fields, str): + exclude_fields = exclude_fields.split(",") + if isinstance(include_fields, str): + include_fields = include_fields.split(",") + self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields))) + self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields))) + self.file_suffix = file_suffix + self.symbol_field_name = symbol_field_name + self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path]) + if limit_nums is not None: + self.csv_files = self.csv_files[: int(limit_nums)] self.qlib_dir = Path(qlib_dir).expanduser() self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser() if backup_dir is not None: self._backup_qlib_dir(Path(backup_dir).expanduser()) self.freq = freq - self.calendar_format = "%Y-%m-%d" if self.freq == "day" else "%Y-%m-%d %H:%M:%S" + self.calendar_format = self.DAILY_FORMAT if self.freq == "day" else self.HIGH_FREQ_FORMAT - self.works = works + self.works = max_workers self.date_field_name = date_field_name - self._calendars_dir = self.qlib_dir.joinpath("calendars") - self._features_dir = self.qlib_dir.joinpath("features") - self._instruments_dir = self.qlib_dir.joinpath("instruments") + self._calendars_dir = self.qlib_dir.joinpath(self.CALENDARS_DIR_NAME) + self._features_dir = self.qlib_dir.joinpath(self.FEATURES_DIR_NAME) + self._instruments_dir = self.qlib_dir.joinpath(self.INSTRUMENTS_DIR_NAME) self._calendars_list = [] - self._include_fields = () - self._exclude_fields = () + + self._mode = self.ALL_MODE + self._kwargs = {} def _backup_qlib_dir(self, target_dir: Path): 
shutil.copytree(str(self.qlib_dir.resolve()), str(target_dir.resolve())) - def _get_date_for_df(self, file_path: Path, *, is_begin_end: bool = False): - df = pd.read_csv(str(file_path.resolve())) + def _format_datetime(self, datetime_d: [str, pd.Timestamp]): + datetime_d = pd.Timestamp(datetime_d) + return datetime_d.strftime(self.calendar_format) + + def _get_date( + self, file_or_df: [Path, pd.DataFrame], *, is_begin_end: bool = False, as_set: bool = False + ) -> Iterable[pd.Timestamp]: + if not isinstance(file_or_df, pd.DataFrame): + df = self._get_source_data(file_or_df) + else: + df = file_or_df if df.empty or self.date_field_name not in df.columns.tolist(): - return [] - if is_begin_end: - return [df[self.date_field_name].min(), df[self.date_field_name].max()] - return df[self.date_field_name].tolist() - - def _get_source_data(self, file_path: Path): - df = pd.read_csv(str(file_path.resolve())) + _calendars = pd.Series() + else: + _calendars = df[self.date_field_name] + + if is_begin_end and as_set: + return (_calendars.min(), _calendars.max()), set(_calendars) + elif is_begin_end: + return _calendars.min(), _calendars.max() + elif as_set: + return set(_calendars) + else: + return _calendars.tolist() + + def _get_source_data(self, file_path: Path) -> pd.DataFrame: + df = pd.read_csv(str(file_path.resolve()), low_memory=False) df[self.date_field_name] = df[self.date_field_name].astype(np.datetime64) + # df.drop_duplicates([self.date_field_name], inplace=True) return df - def _file_to_bin(self, file_path: Path = None): - code = file_path.name[: -len(self.FILE_SUFFIX)].strip().lower() - features_dir = self._features_dir.joinpath(code) - features_dir.mkdir(parents=True, exist_ok=True) - calendars_df = pd.DataFrame(data=self._calendars_list, columns=[self.date_field_name]) - calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64) - # read csv file - df = self._get_source_data(file_path) - cal_df = calendars_df[ - (calendars_df[self.date_field_name] >= df[self.date_field_name].min()) - & (calendars_df[self.date_field_name] <= df[self.date_field_name].max()) - ] - cal_df.set_index(self.date_field_name, inplace=True) - df.set_index(self.date_field_name, inplace=True) - r_df = df.reindex(cal_df.index) - date_index = self._calendars_list.index(r_df.index.min()) - for field in ( + def get_symbol_from_file(self, file_path: Path) -> str: + return file_path.name[: -len(self.file_suffix)].strip().lower() + + def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]: + return ( self._include_fields if self._include_fields - else set(r_df.columns) - set(self._exclude_fields) + else set(df_columns) - set(self._exclude_fields) if self._exclude_fields - else r_df.columns - ): - - bin_path = features_dir.joinpath(f"{field}.{self.freq}.bin") - if field not in r_df.columns: - continue - r = np.hstack([date_index, r_df[field]]).astype(" List[pd.Timestamp]: return sorted( map( pd.Timestamp, @@ -118,133 +159,305 @@ def _read_calendar(calendar_path: Path): ) ) - def dump_features( - self, - calendar_path: str = None, - include_fields: tuple = None, - exclude_fields: tuple = None, - ): - """dump features - - Parameters - --------- - calendar_path: str - calendar path + def _read_instruments(self, instrument_path: Path) -> pd.DataFrame: + return pd.read_csv( + instrument_path, + sep=self.INSTRUMENTS_SEP, + names=[self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD], + ) - include_fields: str - dump fields + def save_calendars(self, 
calendars_data: list): + self._calendars_dir.mkdir(parents=True, exist_ok=True) + calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve()) + result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data)) + np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8") - exclude_fields: str - fields not dumped + def save_instruments(self, instruments_data: Union[list, pd.DataFrame]): + self._instruments_dir.mkdir(parents=True, exist_ok=True) + instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve()) + if isinstance(instruments_data, pd.DataFrame): + instruments_data = instruments_data.loc[:, [self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD]] + instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP) + else: + np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8") + + def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame: + # calendars + calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name]) + calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype(np.datetime64) + cal_df = calendars_df[ + (calendars_df[self.date_field_name] >= df[self.date_field_name].min()) + & (calendars_df[self.date_field_name] <= df[self.date_field_name].max()) + ] + # align index + cal_df.set_index(self.date_field_name, inplace=True) + df.set_index(self.date_field_name, inplace=True) + r_df = df.reindex(cal_df.index) + return r_df - Notes - --------- - python dump_bin.py dump_features --csv_path --qlib_dir + @staticmethod + def get_datetime_index(df: pd.DataFrame, calendar_list: List[pd.Timestamp]) -> int: + return calendar_list.index(df.index.min()) + + def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path): + if df.empty: + logger.warning(f"{features_dir.name} data is None or empty") + return + # align index + _df = self.data_merge_calendar(df, self._calendars_list) + date_index = self.get_datetime_index(_df, calendar_list) + for field in self.get_dump_fields(_df.columns): + bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}") + if field not in _df.columns: + continue + if self._mode == self.UPDATE_MODE: + # update + with bin_path.open("ab") as fp: + np.array(_df[field]).astype(" --qlib_dir + def _dump_instruments(self): + logger.info("start dump instruments......") + self.save_instruments(self._kwargs["date_range_list"]) + logger.info("end of instruments dump.\n") - Examples - --------- - python dump_bin.py dump_calendars --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data - """ - logger.info("start dump calendars......") - calendar_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve()) - all_datetime = set() + def _dump_features(self): + logger.info("start dump features......") + _dump_func = partial(self._dump_bin, calendar_list=self._calendars_list) with tqdm(total=len(self.csv_files)) as p_bar: - with ThreadPoolExecutor(max_workers=self.works) as executor: - for temp_datetime in executor.map(self._get_date_for_df, self.csv_files): - all_datetime = all_datetime | set(temp_datetime) + with ProcessPoolExecutor(max_workers=self.works) as executor: + for _ in executor.map(_dump_func, self.csv_files): p_bar.update() - self._calendars_list = sorted(map(pd.Timestamp, all_datetime)) - self._calendars_dir.mkdir(parents=True, exist_ok=True) - result_calendar_list = 
list(map(lambda x: x.strftime(self.calendar_format), self._calendars_list)) - np.savetxt(calendar_path, result_calendar_list, fmt="%s", encoding="utf-8") - logger.info("end of calendars dump.\n") + logger.info("end of features dump.\n") - def dump_instruments(self): - """dump instruments + def dump(self): + self._get_all_date() + self._dump_calendars() + self._dump_instruments() + self._dump_features() - Notes - --------- - python dump_bin.py dump_instruments --csv_path --qlib_dir - Examples - --------- - python dump_bin.py dump_instruments --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data - """ +class DumpDataFix(DumpDataAll): + def _dump_instruments(self): logger.info("start dump instruments......") - symbol_list = list(map(lambda x: x.name[: -len(self.FILE_SUFFIX)], self.csv_files)) - _result_list = [] - _fun = partial(self._get_date_for_df, is_begin_end=True) - with tqdm(total=len(symbol_list)) as p_bar: - with ThreadPoolExecutor(max_workers=self.works) as execute: - for symbol, res in zip(symbol_list, execute.map(_fun, self.csv_files)): - if res: - begin_time = res[0] - end_time = res[-1] - _result_list.append(f"{symbol.upper()}\t{begin_time}\t{end_time}") + _fun = partial(self._get_date, is_begin_end=True) + new_stock_files = sorted(filter(lambda x: x.name not in self._old_instruments, self.csv_files)) + with tqdm(total=len(new_stock_files)) as p_bar: + with ProcessPoolExecutor(max_workers=self.works) as execute: + for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)): + if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp): + symbol = self.get_symbol_from_file(file_path).upper() + _dt_map = self._old_instruments.setdefault(symbol, dict()) + _dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time) + _dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time) p_bar.update() - - self._instruments_dir.mkdir(parents=True, exist_ok=True) - to_path = str(self._instruments_dir.joinpath("all.txt").resolve()) - np.savetxt(to_path, _result_list, fmt="%s", encoding="utf-8") + self.save_instruments(pd.DataFrame.from_dict(self._old_instruments, orient="index")) logger.info("end of instruments dump.\n") - def dump(self, include_fields: str = None, exclude_fields: tuple = None): - """dump data + def dump(self): + self._calendars_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt")) + # noinspection PyAttributeOutsideInit + self._old_instruments = self._read_instruments( + self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME) + ).to_dict( + orient="index" + ) # type: dict + self._dump_instruments() + self._dump_features() + + +class DumpDataUpdate(DumpDataBase): + def __init__( + self, + csv_path: str, + qlib_dir: str, + backup_dir: str = None, + freq: str = "day", + max_workers: int = 16, + date_field_name: str = "date", + file_suffix: str = ".csv", + symbol_field_name: str = "symbol", + exclude_fields: str = "", + include_fields: str = "", + limit_nums: int = None, + ): + """ Parameters ---------- - include_fields: str + csv_path: str + stock data path or directory + qlib_dir: str + qlib(dump) data director + backup_dir: str, default None + if backup_dir is not None, backup qlib_dir to backup_dir + freq: str, default "day" + transaction frequency + max_workers: int, default None + number of threads + date_field_name: str, default "date" + the name of the date field in the csv + file_suffix: str, default ".csv" + file suffix + symbol_field_name: str, default 
"symbol" + symbol field name + include_fields: tuple dump fields - - exclude_fields: str + exclude_fields: tuple fields not dumped - - Examples - --------- - python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --include_fields open,close,high,low,volume,factor - python dump_bin.py dump --csv_path ~/tmp/stock_data --qlib_dir ~/tmp/qlib_data --exclude_fields date,code,timestamp,code_name + limit_nums: int + Use when debugging, default None """ - if isinstance(exclude_fields, str): - exclude_fields = exclude_fields.split(",") - if isinstance(include_fields, str): - include_fields = include_fields.split(",") - self.dump_calendars() - self.dump_features(include_fields=include_fields, exclude_fields=exclude_fields) - self.dump_instruments() + super().__init__( + csv_path, + qlib_dir, + backup_dir, + freq, + max_workers, + date_field_name, + file_suffix, + symbol_field_name, + exclude_fields, + include_fields, + ) + self._mode = self.UPDATE_MODE + self._old_calendar_list = self._read_calendars(self._calendars_dir.joinpath(f"{self.freq}.txt")) + self._update_instruments = self._read_instruments( + self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME) + ).to_dict( + orient="index" + ) # type: dict + + # load all csv files + self._all_data = self._load_all_source_data() # type: pd.DataFrame + self._update_calendars = sorted( + filter(lambda x: x > self._old_calendar_list[-1], self._all_data[self.date_field_name].unique()) + ) + self._new_calendar_list = self._old_calendar_list + self._update_calendars + + def _load_all_source_data(self): + # NOTE: Need more memory + logger.info("start load all source data....") + all_df = [] + + def _read_csv(file_path: Path): + if self._include_fields: + _df = pd.read_csv(file_path, usecols=self._include_fields) + else: + _df = pd.read_csv(file_path) + if self.symbol_field_name not in _df.columns: + _df[self.symbol_field_name] = self.get_symbol_from_file(file_path) + return _df + + with tqdm(total=len(self.csv_files)) as p_bar: + with ThreadPoolExecutor(max_workers=self.works) as executor: + for df in executor.map(_read_csv, self.csv_files): + if df: + all_df.append(df) + p_bar.update() + + logger.info("end of load all data.\n") + return pd.concat(all_df, sort=False) + + def _dump_calendars(self): + pass + + def _dump_instruments(self): + pass + + def _dump_features(self): + logger.info("start dump features......") + error_code = {} + with ProcessPoolExecutor(max_workers=self.works) as executor: + futures = {} + for _code, _df in self._all_data.groupby(self.symbol_field_name): + _code = str(_code).upper() + _start, _end = self._get_date(_df, is_begin_end=True) + if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)): + continue + if _code in self._update_instruments: + self._update_instruments[_code]["end_time"] = _end + futures[executor.submit(self._dump_bin, _df, self._update_calendars)] = _code + else: + # new stock + _dt_range = self._update_instruments.setdefault(_code, dict()) + _dt_range["start_time"] = _start + _dt_range["end_time"] = _end + futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code + + for _future in tqdm(as_completed(futures)): + try: + _future.result() + except Exception: + error_code[futures[_future]] = traceback.format_exc() + logger.info(f"dump bin errors: {error_code}") + + logger.info("end of features dump.\n") + + def dump(self): + self.save_calendars(self._new_calendar_list) + self._dump_features() + 
self.save_instruments(pd.DataFrame.from_dict(self._update_instruments, orient="index")) if __name__ == "__main__": - fire.Fire(DumpData) + fire.Fire({"dump_all": DumpDataAll, "dump_fix": DumpDataFix, "dump_update": DumpDataUpdate}) diff --git a/scripts/get_data.py b/scripts/get_data.py index f40bc7d316..661e31c5f3 100644 --- a/scripts/get_data.py +++ b/scripts/get_data.py @@ -55,7 +55,7 @@ def _unzip(file_path: Path, target_dir: Path): for _file in tqdm(zp.namelist()): zp.extract(_file, str(target_dir.resolve())) - def qlib_data_cn(self, name="qlib_data_cn", target_dir="~/.qlib/qlib_data/cn_data", version="latest"): + def qlib_data(self, name="qlib_data", target_dir="~/.qlib/qlib_data/cn_data", version="latest", interval="1d", region="cn"): """download cn qlib data from remote Parameters @@ -63,18 +63,25 @@ def qlib_data_cn(self, name="qlib_data_cn", target_dir="~/.qlib/qlib_data/cn_dat target_dir: str data save directory name: str - dataset name, value from [qlib_data_cn, qlib_data_cn_simple], by default qlib_data_cn + dataset name, value from [qlib_data, qlib_data_simple], by default qlib_data version: str data version, value from [v0, v1, ..., latest], by default latest + interval: str + data freq, value from [1d], by default 1d + region: str + data region, value from [cn, us], by default cn Examples --------- - python get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data --version latest + python get_data.py qlib_data --name qlib_data --target_dir ~/.qlib/qlib_data/cn_data --version latest --interval 1d --region cn ------- """ - file_name = f"{name}_{version}.zip" - self._download_data(file_name, target_dir) + # TODO: The US stock code contains "PRN", and the directory cannot be created on Windows system + if region.lower() == "us": + logger.warning(f"The US stock code contains 'PRN', and the directory cannot be created on Windows system") + file_name = f"{name}_{region.lower()}_{interval.lower()}_{version}.zip" + self._download_data(file_name.lower(), target_dir) def csv_data_cn(self, target_dir="~/.qlib/csv_data/cn_data"): """download cn csv data from remote diff --git a/tests/dataset_tests/test_dataset.py b/tests/dataset_tests/test_dataset.py index 5a70fee496..9d282b1672 100644 --- a/tests/dataset_tests/test_dataset.py +++ b/tests/dataset_tests/test_dataset.py @@ -18,7 +18,7 @@ def setUpClass(cls) -> None: sys.path.append(str(Path(__file__).resolve().parent.parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri) + GetData().qlib_data(name="qlib_data_simple", target_dir=provider_uri) qlib.init(provider_uri=provider_uri, region=REG_CN) def testCSI300(self): diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index ff80f8520d..2930489a29 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -149,7 +149,7 @@ def setUpClass(cls) -> None: sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData - GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=provider_uri) + GetData().qlib_data(name="qlib_data_simple", region="cn", version="latest", interval="1d", target_dir=provider_uri) qlib.init(provider_uri=provider_uri, region=REG_CN) def test_0_train(self): diff --git a/tests/test_dump_data.py b/tests/test_dump_data.py index dbf4fb0825..01e6a3758f 100644 --- a/tests/test_dump_data.py +++ b/tests/test_dump_data.py @@ -14,7 +14,7 @@ 
sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts"))) from get_data import GetData -from dump_bin import DumpData +from dump_bin import DumpDataAll, DumpDataFix DATA_DIR = Path(__file__).parent.joinpath("test_dump_data") @@ -36,7 +36,7 @@ class TestDumpData(unittest.TestCase): @classmethod def setUpClass(cls) -> None: GetData().csv_data_cn(SOURCE_DIR) - TestDumpData.DUMP_DATA = DumpData(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR) + TestDumpData.DUMP_DATA = DumpDataAll(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS) TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv"))) provider_uri = str(QLIB_DIR.resolve()) qlib.init( @@ -49,8 +49,10 @@ def setUpClass(cls) -> None: def tearDownClass(cls) -> None: shutil.rmtree(str(DATA_DIR.resolve())) - def test_0_dump_calendars(self): - self.DUMP_DATA.dump_calendars() + def test_0_dump_bin(self): + self.DUMP_DATA.dump() + + def test_1_dump_calendars(self): ori_calendars = set( map( pd.Timestamp, @@ -60,23 +62,21 @@ def test_0_dump_calendars(self): res_calendars = set(D.calendar()) assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed" - def test_1_dump_instruments(self): - self.DUMP_DATA.dump_instruments() + def test_2_dump_instruments(self): ori_ins = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv"))) res_ins = set(D.list_instruments(D.instruments("all"), as_list=True)) assert len(ori_ins - res_ins) == len(ori_ins - res_ins) == 0, "dump instruments failed" - def test_2_dump_features(self): - self.DUMP_DATA.dump_features(include_fields=self.FIELDS) + def test_3_dump_features(self): df = D.features(self.STOCK_NAMES, self.QLIB_FIELDS) TestDumpData.SIMPLE_DATA = df.loc(axis=0)[self.STOCK_NAMES[0], :] self.assertFalse(df.dropna().empty, "features data failed") self.assertListEqual(list(df.columns), self.QLIB_FIELDS, "features columns failed") - def test_3_dump_features_simple(self): + def test_4_dump_features_simple(self): stock = self.STOCK_NAMES[0] - dump_data = DumpData(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR) - dump_data.dump_features(include_fields=self.FIELDS, calendar_path=QLIB_DIR.joinpath("calendars", "day.txt")) + dump_data = DumpDataFix(csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS) + dump_data.dump() df = D.features([stock], self.QLIB_FIELDS) diff --git a/tests/test_get_data.py b/tests/test_get_data.py index d0f5ca591e..732d866dd0 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -37,7 +37,7 @@ def tearDownClass(cls) -> None: def test_0_qlib_data(self): - GetData().qlib_data_cn(name="qlib_data_cn_simple", target_dir=QLIB_DIR) + GetData().qlib_data(name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", version="latest") df = D.features(D.instruments("csi300"), self.FIELDS) self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed") self.assertFalse(df.dropna().empty, "get qlib data failed")
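For reviewers, here is a minimal end-to-end sketch of the data workflow after this patch, using the CN daily settings shown in the docstrings above. The directories and date range are illustrative, the collector commands assume they are run from `scripts/data_collector/yahoo/` as in the README, and the last two commands assume they are run from `scripts/`:

```bash
# download raw daily CN data from Yahoo Finance (dates and paths are examples)
python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d

# normalize the collected CSVs against the CN trading calendar
python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN

# convert the normalized CSVs to Qlib binary format via the new dump_all entry point
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/normalize --qlib_dir ~/.qlib/stock_data/qlib_data --include_fields close,open,high,low,volume,change,factor

# or skip collection and fetch the prebuilt dataset with the renamed qlib_data command
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn
```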