feat: data improve, support parquet #1966

Open · wants to merge 4 commits into base: main
4 changes: 2 additions & 2 deletions README.md
@@ -229,10 +229,10 @@ Load and prepare data by running the following code:
### Get with module
```bash
# get 1d data
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn

# get 1min data
python -m qlib.run.get_data qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min
python -m qlib.cli.data qlib_data --target_dir ~/.qlib/qlib_data/cn_data_1min --region cn --interval 1min

```
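The same downloads can presumably also be triggered from Python with the existing `GetData` helper (not touched by this PR); a minimal sketch, assuming its current signature, with illustrative target directories:

```python
# A sketch, assuming GetData().qlib_data keeps its current parameters; paths are illustrative.
from qlib.tests.data import GetData

GetData().qlib_data(target_dir="~/.qlib/qlib_data/cn_data", region="cn")
GetData().qlib_data(target_dir="~/.qlib/qlib_data/cn_data_1min", region="cn", interval="1min")
```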

2 changes: 1 addition & 1 deletion docs/component/data.rst
@@ -167,7 +167,7 @@ Supposed that users prepare their CSV format data in the directory ``~/.qlib/csv

.. code-block:: bash

python scripts/dump_bin.py dump_all --csv_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor
python scripts/dump_bin.py dump_all --data_path ~/.qlib/csv_data/my_data --qlib_dir ~/.qlib/qlib_data/my_data --include_fields open,close,high,low,volume,factor

For other supported parameters when dumping the data into `.bin` file, users can refer to the information by running the following commands:

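Because the dumper now globs files by suffix rather than assuming CSV, a directory of parquet files could presumably be dumped through the Python API as well. The sketch below is an assumption based on the `DumpDataBase` constructor shown later in this diff (in particular the `file_suffix` argument); all paths are illustrative.

```python
# A sketch, not documented usage: run from qlib/scripts with pyarrow installed.
from dump_bin import DumpDataAll

DumpDataAll(
    data_path="~/.qlib/parquet_data/my_data",    # directory of *.parquet files (illustrative)
    qlib_dir="~/.qlib/qlib_data/my_data",
    file_suffix=".parquet",                      # assumed to work like the CSV suffix
    include_fields="open,close,high,low,volume,factor",
).dump()
```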
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -103,4 +103,4 @@ packages = [
]

[project.scripts]
qrun = "qlib.workflow.cli:run"
qrun = "qlib.cli.run:run"
1 change: 1 addition & 0 deletions qlib/backtest/exchange.py
@@ -897,6 +897,7 @@ def _calc_trade_info_by_order(
# if we don't know current position, we choose to sell all
# Otherwise, we clip the amount based on current position
if position is not None:
# TODO: make the trading shortable
current_amount = (
position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
)
File renamed without changes.
File renamed without changes.
File renamed without changes.
3 changes: 2 additions & 1 deletion qlib/config.py
@@ -49,6 +49,7 @@ class QSettings(BaseSettings):
"""

mlflow: MLflowSettings = MLflowSettings()
provider_uri: str = "~/.qlib/qlib_data/cn_data"

model_config = SettingsConfigDict(
env_prefix="QLIB_",
@@ -261,7 +262,7 @@ def register_from_C(config, skip_register=True):
},
"client": {
# config it in user's own code
"provider_uri": "~/.qlib/qlib_data/cn_data",
"provider_uri": QSETTINGS.provider_uri,
# cache
# Using parameter 'remote' to announce the client is using server_cache, and the writing access will be disabled.
# Disable cache by default. Avoid introduce advanced features for beginners
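Given the `QLIB_` environment prefix on `QSettings`, the new `provider_uri` default can presumably be overridden from the environment. A minimal sketch, assuming pydantic-settings maps `QLIB_PROVIDER_URI` onto the new field:

```python
# Sketch only: set the variable before importing qlib, since QSETTINGS is built at import time.
import os

os.environ["QLIB_PROVIDER_URI"] = "~/.qlib/qlib_data/my_region_data"  # illustrative path

import qlib

qlib.init()  # the default client provider_uri should now be the value above
```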
2 changes: 1 addition & 1 deletion scripts/data_collector/baostock_5min/README.md
@@ -77,5 +77,5 @@
- examples:
```bash
# dump 5min cn
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol
python dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol
```
2 changes: 1 addition & 1 deletion scripts/data_collector/crypto/README.md
@@ -28,7 +28,7 @@ python collector.py normalize_data --source_dir ~/.qlib/crypto_data/source/1d --

# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/crypto_data/source/1d_nor --qlib_dir ~/.qlib/qlib_data/crypto_data --freq day --date_field_name date --include_fields prices,total_volumes,market_caps
python dump_bin.py dump_all --data_path ~/.qlib/crypto_data/source/1d_nor --qlib_dir ~/.qlib/qlib_data/crypto_data --freq day --date_field_name date --include_fields prices,total_volumes,market_caps

```

2 changes: 1 addition & 1 deletion scripts/data_collector/fund/README.md
@@ -25,7 +25,7 @@ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_data

# dump data
cd qlib/scripts
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
python dump_bin.py dump_all --data_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ

```

2 changes: 1 addition & 1 deletion scripts/data_collector/pit/README.md
@@ -36,5 +36,5 @@ python collector.py normalize_data --interval quarterly --source_dir ~/.qlib/sto

```bash
cd qlib/scripts
python dump_pit.py dump --csv_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly
python dump_pit.py dump --data_path ~/.qlib/stock_data/source/pit_normalized --qlib_dir ~/.qlib/qlib_data/cn_data --interval quarterly
```
4 changes: 2 additions & 2 deletions scripts/data_collector/yahoo/README.md
@@ -152,9 +152,9 @@ pip install -r requirements.txt
- examples:
```bash
# dump 1d cn
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_data --freq day --exclude_fields date,symbol
python dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_data --freq day --exclude_fields date,symbol
# dump 1min cn
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/cn_data_1min --freq 1min --exclude_fields date,symbol
python dump_bin.py dump_all --data_path ~/.qlib/stock_data/source/cn_1min_nor --qlib_dir ~/.qlib/qlib_data/cn_data_1min --freq 1min --exclude_fields date,symbol
```

### Automatic update of daily frequency data(from yahoo finance)
2 changes: 1 addition & 1 deletion scripts/data_collector/yahoo/collector.py
@@ -856,7 +856,7 @@ def normalize_data_1d_extend(

3. normalize new source data(from step 2): python scripts/data_collector/yahoo/collector.py normalize_data_1d_extend --old_qlib_dir <dir1> --source_dir <dir2> --normalize_dir <dir3> --region CN --interval 1d

4. dump data: python scripts/dump_bin.py dump_update --csv_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date
4. dump data: python scripts/dump_bin.py dump_update --data_path <dir3> --qlib_dir <dir1> --freq day --date_field_name date --symbol_field_name symbol --exclude_fields symbol,date

5. update instrument(eg. csi300): python python scripts/data_collector/cn_index/collector.py --index_name CSI300 --qlib_dir <dir1> --method parse_instruments

83 changes: 60 additions & 23 deletions scripts/dump_bin.py
@@ -17,6 +17,39 @@
from qlib.utils import fname_to_code, code_to_fname


def read_as_df(file_path: Union[str, Path], **kwargs) -> pd.DataFrame:
"""
Read a csv or parquet file into a pandas DataFrame.

Parameters
----------
file_path : Union[str, Path]
Path to the data file.
**kwargs :
Additional keyword arguments passed to the underlying pandas
reader.

Returns
-------
pd.DataFrame
"""
file_path = Path(file_path).expanduser()
suffix = file_path.suffix.lower()

keep_keys = {".csv": ("low_memory",)}
kept_kwargs = {}
for k in keep_keys.get(suffix, []):
if k in kwargs:
kept_kwargs[k] = kwargs[k]

if suffix == ".csv":
return pd.read_csv(file_path, **kept_kwargs)
elif suffix == ".parquet":
return pd.read_parquet(file_path, **kept_kwargs)
else:
raise ValueError(f"Unsupported file format: {suffix}")


class DumpDataBase:
INSTRUMENTS_START_FIELD = "start_datetime"
INSTRUMENTS_END_FIELD = "end_datetime"
@@ -34,7 +67,7 @@ class DumpDataBase:

def __init__(
self,
csv_path: str,
data_path: str,
qlib_dir: str,
backup_dir: str = None,
freq: str = "day",
@@ -50,7 +83,7 @@ def __init__(

Parameters
----------
csv_path: str
data_path: str
stock data path or directory
qlib_dir: str
qlib(dump) data director
@@ -73,7 +106,7 @@ limit_nums: int
limit_nums: int
Use when debugging, default None
"""
csv_path = Path(csv_path).expanduser()
data_path = Path(data_path).expanduser()
if isinstance(exclude_fields, str):
exclude_fields = exclude_fields.split(",")
if isinstance(include_fields, str):
@@ -82,9 +115,9 @@ def __init__(
self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields)))
self.file_suffix = file_suffix
self.symbol_field_name = symbol_field_name
self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path])
self.df_files = sorted(data_path.glob(f"*{self.file_suffix}") if data_path.is_dir() else [data_path])
if limit_nums is not None:
self.csv_files = self.csv_files[: int(limit_nums)]
self.df_files = self.df_files[: int(limit_nums)]
self.qlib_dir = Path(qlib_dir).expanduser()
self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser()
if backup_dir is not None:
@@ -134,13 +167,14 @@ def _get_date(
return _calendars.tolist()

def _get_source_data(self, file_path: Path) -> pd.DataFrame:
df = pd.read_csv(str(file_path.resolve()), low_memory=False)
df[self.date_field_name] = df[self.date_field_name].astype(str).astype("datetime64[ns]")
df = read_as_df(file_path, low_memory=False)
if self.date_field_name in df.columns:
df[self.date_field_name] = pd.to_datetime(df[self.date_field_name])
# df.drop_duplicates([self.date_field_name], inplace=True)
return df

def get_symbol_from_file(self, file_path: Path) -> str:
return fname_to_code(file_path.name[: -len(self.file_suffix)].strip().lower())
return fname_to_code(file_path.stem.strip().lower())

def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
return (
@@ -274,10 +308,10 @@ def _get_all_date(self):
all_datetime = set()
date_range_list = []
_fun = partial(self._get_date, as_set=True, is_begin_end=True)
with tqdm(total=len(self.csv_files)) as p_bar:
with tqdm(total=len(self.df_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as executor:
for file_path, ((_begin_time, _end_time), _set_calendars) in zip(
self.csv_files, executor.map(_fun, self.csv_files)
self.df_files, executor.map(_fun, self.df_files)
):
all_datetime = all_datetime | _set_calendars
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
@@ -305,9 +339,9 @@ def _dump_instruments(self):
def _dump_features(self):
logger.info("start dump features......")
_dump_func = partial(self._dump_bin, calendar_list=self._calendars_list)
with tqdm(total=len(self.csv_files)) as p_bar:
with tqdm(total=len(self.df_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as executor:
for _ in executor.map(_dump_func, self.csv_files):
for _ in executor.map(_dump_func, self.df_files):
p_bar.update()

logger.info("end of features dump.\n")
@@ -325,16 +359,15 @@ def _dump_instruments(self):
_fun = partial(self._get_date, is_begin_end=True)
new_stock_files = sorted(
filter(
lambda x: fname_to_code(x.name[: -len(self.file_suffix)].strip().lower()).upper()
not in self._old_instruments,
self.csv_files,
lambda x: self.get_symbol_from_file(x).upper() not in self._old_instruments,
self.df_files,
)
)
with tqdm(total=len(new_stock_files)) as p_bar:
with ProcessPoolExecutor(max_workers=self.works) as execute:
for file_path, (_begin_time, _end_time) in zip(new_stock_files, execute.map(_fun, new_stock_files)):
if isinstance(_begin_time, pd.Timestamp) and isinstance(_end_time, pd.Timestamp):
symbol = fname_to_code(self.get_symbol_from_file(file_path).lower()).upper()
symbol = self.get_symbol_from_file(file_path).upper()
_dt_map = self._old_instruments.setdefault(symbol, dict())
_dt_map[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_begin_time)
_dt_map[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end_time)
@@ -359,7 +392,7 @@ def dump(self):
class DumpDataUpdate(DumpDataBase):
def __init__(
self,
csv_path: str,
data_path: str,
qlib_dir: str,
backup_dir: str = None,
freq: str = "day",
@@ -375,7 +408,7 @@ def __init__(

Parameters
----------
csv_path: str
data_path: str
stock data path or directory
qlib_dir: str
qlib(dump) data director
@@ -399,7 +432,7 @@ limit_nums: int
Use when debugging, default None
"""
super().__init__(
csv_path,
data_path,
qlib_dir,
backup_dir,
freq,
@@ -431,15 +464,19 @@ def _load_all_source_data(self):
logger.info("start load all source data....")
all_df = []

def _read_csv(file_path: Path):
_df = pd.read_csv(file_path, parse_dates=[self.date_field_name])
def _read_df(file_path: Path):
_df = read_as_df(file_path)
if self.date_field_name in _df.columns and not np.issubdtype(
_df[self.date_field_name].dtype, np.datetime64
):
_df[self.date_field_name] = pd.to_datetime(_df[self.date_field_name])
if self.symbol_field_name not in _df.columns:
_df[self.symbol_field_name] = self.get_symbol_from_file(file_path)
return _df

with tqdm(total=len(self.csv_files)) as p_bar:
with tqdm(total=len(self.df_files)) as p_bar:
with ThreadPoolExecutor(max_workers=self.works) as executor:
for df in executor.map(_read_csv, self.csv_files):
for df in executor.map(_read_df, self.df_files):
if not df.empty:
all_df.append(df)
p_bar.update()
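A minimal, self-contained sketch of how the new `read_as_df` helper behaves for both supported formats (run from `qlib/scripts`; file names are hypothetical and parquet support needs pyarrow or fastparquet):

```python
# Everything except read_as_df itself is illustrative.
import pandas as pd
from dump_bin import read_as_df

sample = pd.DataFrame(
    {"date": pd.date_range("2024-01-01", periods=3), "close": [10.0, 10.5, 10.2]}
)
sample.to_csv("sh600000.csv", index=False)
sample.to_parquet("sh600000.parquet", index=False)

# Dispatch is by suffix: .csv -> pd.read_csv, .parquet -> pd.read_parquet.
# CSV-only kwargs such as low_memory are kept for .csv and dropped for other formats.
csv_df = read_as_df("sh600000.csv", low_memory=False)
parquet_df = read_as_df("sh600000.parquet", low_memory=False)
assert csv_df["close"].tolist() == parquet_df["close"].tolist()
```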
4 changes: 2 additions & 2 deletions tests/test_dump_data.py
@@ -36,7 +36,7 @@ class TestDumpData(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
GetData().download_data(file_name="csv_data_cn.zip", target_dir=SOURCE_DIR)
TestDumpData.DUMP_DATA = DumpDataAll(csv_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
TestDumpData.DUMP_DATA = DumpDataAll(data_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
provider_uri = str(QLIB_DIR.resolve())
qlib.init(
@@ -76,7 +76,7 @@ def test_3_dump_features(self):
def test_4_dump_features_simple(self):
stock = self.STOCK_NAMES[0]
dump_data = DumpDataFix(
csv_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS
data_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS
)
dump_data.dump()
