Commit fc243fd

you-n-g and zhupr authored
Fix Models (#483)
* fix gat dataset
* fix tft model
* Update tft.py
* Fix tft.py

Co-authored-by: Pengrong Zhu <zhu.pengrong@foxmail.com>
1 parent b6a8bd5 commit fc243fd

8 files changed: 50 additions and 8 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -308,6 +308,7 @@ All the models listed above are runnable with ``Qlib``. Users can find the confi
 - Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.

 - Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
+- **NOTE**: Each baseline has different environment dependencies, please make sure that your python version aligns with the requirements(e.g. TFT only supports Python 3.6~3.7 due to the limitation of `tensorflow==1.15.0`)

 ## Run multiple models
 `Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parallel running the same model for multiple times as well, and this will be fixed in the future development too.)

examples/benchmarks/TFT/README.md

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 Users can follow the ``workflow_by_code_tft.py`` to run the benchmark.

 ### Notes
-1. Please be **aware** that this script can only support `Python 3.5 - 3.8`.
+1. Please be **aware** that this script can only support `Python 3.6 - 3.7`.
 2. If the CUDA version on your machine is not 10.0, please remember to run the following commands `conda install anaconda cudatoolkit=10.0` and `conda install cudnn` on your machine.
 3. The model must run in GPU, or an error will be raised.
 4. New datasets should be registered in ``data_formatters``, for detail please visit the source.
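
The narrowed version range in note 1 could also be enforced at script start-up rather than only documented. The following is a minimal sketch, not part of this commit: the helper name `python_supported` and its `low`/`high` parameters are hypothetical, and the bounds simply mirror the `Python 3.6 - 3.7` range stated above.

```python
import sys


def python_supported(version_info=sys.version_info, low=(3, 6), high=(3, 7)):
    """Return True when the (major, minor) version lies within [low, high].

    Hypothetical helper: the default bounds mirror the README note above,
    they are not read from any Qlib configuration.
    """
    return low <= tuple(version_info[:2]) <= high


if __name__ == "__main__":
    if not python_supported():
        # Fail early with a readable message instead of a TensorFlow import error.
        raise SystemExit("The TFT benchmark expects Python 3.6 - 3.7.")
```

Comparing only the first two fields of `sys.version_info` keeps patch releases such as 3.7.9 inside the accepted range.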
Lines changed: 1 addition & 2 deletions
@@ -1,3 +1,2 @@
 tensorflow-gpu==1.15.0
-numpy == 1.19.4
-pandas==1.1.0
+pandas==1.1.0

examples/benchmarks/TFT/tft.py

Lines changed: 24 additions & 1 deletion
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+from pathlib import Path
+from typing import Union
 import numpy as np
 import pandas as pd
 import tensorflow.compat.v1 as tf
@@ -243,7 +245,7 @@ def extract_numerical_data(data):
         # extract_numerical_data(targets), extract_numerical_data(p90_forecast),
         # 0.9)
         tf.keras.backend.set_session(default_keras_session)
-        print("Training completed.".format(dte.datetime.now()))
+        print("Training completed at {}.".format(dte.datetime.now()))
         # ===========================Training Process===========================

     def predict(self, dataset):
@@ -289,3 +291,24 @@ def finetune(self, dataset: DatasetH):
             dataset for finetuning
         """
         pass
+
+    def to_pickle(self, path: Union[Path, str]):
+        """
+        Tensorflow model can't be dumped directly.
+        So the data should be saved separately
+
+        **TODO**: Please implement the function to load the files
+
+        Parameters
+        ----------
+        path : Union[Path, str]
+            the target path to be dumped
+        """
+        # save tensorflow model
+        # path = Path(path)
+        # path.mkdir(parents=True)
+        # self.model.save(path)
+
+        # save qlib model wrapper
+        self.model = None
+        super(TFTModel, self).to_pickle(path / "qlib_model")
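
One detail of `to_pickle` worth checking: the signature accepts `Union[Path, str]`, but `path = Path(path)` is commented out, so the final `path / "qlib_model"` only works when a `Path` is passed. The sketch below isolates that behaviour; `naive_target` and `normalized_target` are hypothetical names used for illustration, not functions from Qlib.

```python
from pathlib import Path
from typing import Union


def naive_target(path):
    # Mirrors the expression in the hunk above: no conversion before "/".
    return path / "qlib_model"


def normalized_target(path: Union[Path, str]) -> Path:
    # Converting first makes the "/" operator valid for both str and Path.
    return Path(path) / "qlib_model"


def naive_fails_on_str() -> bool:
    # str does not implement "/", so a plain string raises TypeError.
    try:
        naive_target("out")
    except TypeError:
        return True
    return False
```

Callers that always pass a `Path` are unaffected; a plain string would need the conversion.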

qlib/contrib/model/pytorch_gats_ts.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@

 class DailyBatchSampler(Sampler):
     def __init__(self, data_source):
-
         self.data_source = data_source
         # calculate number of samples in each batch
         self.daily_count = pd.Series(index=self.data_source.get_index()).groupby("datetime").size().values

qlib/utils/serial.py

Lines changed: 19 additions & 0 deletions
@@ -122,3 +122,22 @@ def get_backend(cls):
             return dill
         else:
             raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
+
+    @staticmethod
+    def general_dump(obj, path: Union[Path, str]):
+        """
+        A general dumping method for object
+
+        Parameters
+        ----------
+        obj : object
+            the object to be dumped
+        path : Union[Path, str]
+            the target path the data will be dumped
+        """
+        path = Path(path)
+        if isinstance(obj, Serializable):
+            obj.to_pickle(path)
+        else:
+            with path.open("wb") as f:
+                pickle.dump(obj, f)
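
The dispatch in `general_dump` can be exercised in isolation. The sketch below is a simplified stand-in, not Qlib's actual `Serializable`: the base-class `to_pickle` here just pickles `__dict__`, and `CountingModel` and `demo` are hypothetical names. It shows that an object with its own `to_pickle` override gets that override called, while a plain object falls back to `pickle.dump`.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Union


class Serializable:
    """Simplified stand-in for qlib.utils.serial.Serializable (assumption)."""

    def to_pickle(self, path: Union[Path, str]):
        with Path(path).open("wb") as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def general_dump(obj, path: Union[Path, str]):
        # Same branching as the hunk above.
        path = Path(path)
        if isinstance(obj, Serializable):
            obj.to_pickle(path)
        else:
            with path.open("wb") as f:
                pickle.dump(obj, f)


class CountingModel(Serializable):
    """Hypothetical model that records how often its override is used."""

    def __init__(self):
        self.calls = 0

    def to_pickle(self, path):
        self.calls += 1
        super().to_pickle(path)


def demo():
    tmp = Path(tempfile.mkdtemp())
    model = CountingModel()
    Serializable.general_dump(model, tmp / "model")     # uses the override
    Serializable.general_dump({"a": 1}, tmp / "plain")  # plain pickle.dump
    with (tmp / "plain").open("rb") as f:
        plain = pickle.load(f)
    return model.calls, plain
```

This is the property the `recorder.py` change relies on: `save_objects` no longer assumes that every artifact can be written with a bare `pickle.dump`.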

qlib/workflow/recorder.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+from qlib.utils.serial import Serializable
 import mlflow, logging
 import shutil, os, pickle, tempfile, codecs, pickle
 from pathlib import Path
@@ -307,8 +308,8 @@ def save_objects(self, local_path=None, artifact_path=None, **kwargs):
         else:
             temp_dir = Path(tempfile.mkdtemp()).resolve()
             for name, data in kwargs.items():
-                with (temp_dir / name).open("wb") as f:
-                    pickle.dump(data, f)
+                path = temp_dir / name
+                Serializable.general_dump(data, path)
                 self.client.log_artifact(self.id, temp_dir / name, artifact_path)
             shutil.rmtree(temp_dir)

setup.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@
     "statsmodels",
     "xlrd>=1.0.0",
     "plotly==4.12.0",
-    "matplotlib==3.3",
+    "matplotlib>=3.3",
     "tables>=3.6.1",
     "pyyaml>=5.3.1",
     "mlflow>=1.12.1",
