Merge pull request #326 from WenjieDu/dev
Adding PatchTST, renaming d_inner to d_ffn, and refactoring Autoformer
WenjieDu authored Mar 29, 2024
2 parents bf53667 + 2a33326 commit a478836
Showing 32 changed files with 774 additions and 85 deletions.
2 changes: 1 addition & 1 deletion docs/examples.rst
@@ -45,7 +45,7 @@ You can also find a simple and quick-start tutorial notebook on Google Colab
n_features=37,
n_layers=2,
d_model=256,
d_inner=128,
d_ffn=128,
n_heads=4,
d_k=64,
d_v=64,
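For context, a minimal sketch of how the renamed argument reads in a full call after this change, assuming the docs snippet belongs to the SAITS quick-start on the PhysioNet-2012 demo; n_steps=48 and the training hyperparameters below are assumptions, not taken from the diff.

from pypots.imputation import SAITS

# Hypothetical quick-start configuration; only the d_ffn rename is from this commit.
saits = SAITS(
    n_steps=48,
    n_features=37,
    n_layers=2,
    d_model=256,
    d_ffn=128,   # formerly d_inner=128 before this commit
    n_heads=4,
    d_k=64,
    d_v=64,
    dropout=0.1,
    epochs=10,
)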
6 changes: 3 additions & 3 deletions pypots/classification/raindrop/model.py
@@ -46,7 +46,7 @@ class Raindrop(BaseNNClassifier):
The dimension of the Transformer encoder backbone.
It is the input dimension of the multi-head self-attention layers.
d_inner :
d_ffn :
The dimension of the layer in the Feed-Forward Networks (FFN).
n_heads :
@@ -121,7 +121,7 @@ def __init__(
n_classes,
n_layers,
d_model,
d_inner,
d_ffn,
n_heads,
dropout,
d_static=0,
@@ -156,7 +156,7 @@ def __init__(
n_features,
n_layers,
d_model,
d_inner,
d_ffn,
n_heads,
n_classes,
dropout,
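The same rename applies to the classifier's public constructor. A minimal usage sketch; the leading n_steps/n_features arguments and all values are illustrative assumptions, not part of this hunk, and should satisfy the divisibility check shown in modules/core.py below.

from pypots.classification import Raindrop

# Illustrative values only; d_ffn replaces the old d_inner argument.
raindrop = Raindrop(
    n_steps=48,
    n_features=37,
    n_classes=2,
    n_layers=2,
    d_model=148,
    d_ffn=256,
    n_heads=2,
    dropout=0.3,
    epochs=10,
)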
8 changes: 4 additions & 4 deletions pypots/classification/raindrop/modules/core.py
@@ -44,7 +44,7 @@ def __init__(
n_features,
n_layers,
d_model,
d_inner,
d_ffn,
n_heads,
n_classes,
dropout=0.3,
@@ -59,7 +59,7 @@ def __init__(
self.n_layers = n_layers
self.n_features = n_features
self.d_model = d_model
self.d_inner = d_inner
self.d_ffn = d_ffn
self.n_heads = n_heads
self.n_classes = n_classes
self.dropout = dropout
@@ -84,13 +84,13 @@ def __init__(
dim_check = n_features * (self.d_ob + d_pe)
assert dim_check % n_heads == 0, "dim_check must be divisible by n_heads"
encoder_layers = TransformerEncoderLayer(
n_features * (self.d_ob + d_pe), n_heads, d_inner, dropout
n_features * (self.d_ob + d_pe), n_heads, d_ffn, dropout
)
else:
dim_check = d_model + d_pe
assert dim_check % n_heads == 0, "dim_check must be divisible by n_heads"
encoder_layers = TransformerEncoderLayer(
d_model + d_pe, n_heads, d_inner, dropout
d_model + d_pe, n_heads, d_ffn, dropout
)
self.transformer_encoder = TransformerEncoder(encoder_layers, n_layers)

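Under the hood, the renamed value is passed straight through as PyTorch's feed-forward width. A small standalone sketch of the constraint the assert guards (the numbers are illustrative, not from the diff):

import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer

d_model, d_pe, n_heads, d_ffn = 64, 16, 4, 256
dim_check = d_model + d_pe       # embedding width fed to the attention layers
assert dim_check % n_heads == 0  # each head gets dim_check // n_heads dimensions

layer = TransformerEncoderLayer(dim_check, n_heads, d_ffn, 0.3)  # d_ffn = dim_feedforward
encoder = TransformerEncoder(layer, 2)
out = encoder(torch.randn(48, 8, dim_check))  # (seq_len, batch, dim) by default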
2 changes: 2 additions & 0 deletions pypots/imputation/__init__.py
@@ -14,6 +14,7 @@
from .transformer import Transformer
from .timesnet import TimesNet
from .autoformer import Autoformer
from .patchtst import PatchTST
from .usgan import USGAN

# naive imputation methods
@@ -26,6 +27,7 @@
"SAITS",
"Transformer",
"TimesNet",
"PatchTST",
"Autoformer",
"BRITS",
"MRNN",
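With the new export in place, the model becomes importable from the imputation subpackage alongside the existing ones; a one-line check, using nothing beyond what the diff itself declares:

from pypots.imputation import SAITS, Transformer, TimesNet, Autoformer, PatchTST

print(PatchTST.__name__)  # "PatchTST"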
2 changes: 1 addition & 1 deletion pypots/imputation/autoformer/data.py
@@ -1,5 +1,5 @@
"""
Dataset class for TimesNet.
Dataset class for Autoformer.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
12 changes: 6 additions & 6 deletions pypots/imputation/autoformer/model.py
@@ -1,5 +1,5 @@
"""
The implementation of Transformer for the partially-observed time-series imputation task.
The implementation of Autoformer for the partially-observed time-series imputation task.
Refer to the paper "Wu, H., Xu, J., Wang, J., & Long, M. (2021).
Autoformer: Decomposition transformers with auto-correlation for long-term series forecasting. NeurIPS 2021.".
@@ -31,7 +31,7 @@

class Autoformer(BaseNNImputer):
"""The PyTorch implementation of the Autoformer model.
TimesNet is originally proposed by Wu et al. in :cite:`wu2021autoformer`.
Autoformer is originally proposed by Wu et al. in :cite:`wu2021autoformer`.
Parameters
----------
@@ -56,7 +56,7 @@ class Autoformer(BaseNNImputer):
factor :
The factor of the auto correlation mechanism for the Autoformer model.
moving_avg_kernel_size :
moving_avg_window_size :
The window size of moving average.
dropout :
@@ -120,7 +120,7 @@ def __init__(
d_model: int,
d_ffn: int,
factor: int,
moving_avg_kernel_size: int,
moving_avg_window_size: int,
dropout: float = 0,
batch_size: int = 32,
epochs: int = 100,
@@ -149,7 +149,7 @@ def __init__(
self.d_model = d_model
self.d_ffn = d_ffn
self.factor = factor
self.moving_avg_kernel_size = moving_avg_kernel_size
self.moving_avg_window_size = moving_avg_window_size
self.dropout = dropout

# set up the model
@@ -161,7 +161,7 @@ def __init__(
self.d_model,
self.d_ffn,
self.factor,
self.moving_avg_kernel_size,
self.moving_avg_window_size,
self.dropout,
)
self._send_model_to_given_device()
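Callers therefore pass moving_avg_window_size instead of moving_avg_kernel_size from this commit on. A hedged sketch of the updated constructor call: the n_steps/n_features/n_layers/n_heads parameters are not visible in this hunk and are assumed here, and every value is a placeholder rather than a recommendation.

from pypots.imputation import Autoformer

# Illustrative configuration only; the rename is the one change from this diff.
autoformer = Autoformer(
    n_steps=48,
    n_features=37,
    n_layers=2,
    n_heads=4,
    d_model=128,
    d_ffn=256,
    factor=3,
    moving_avg_window_size=25,  # renamed from moving_avg_kernel_size
    dropout=0.1,
    epochs=10,
)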
6 changes: 3 additions & 3 deletions pypots/imputation/autoformer/modules/core.py
@@ -29,7 +29,7 @@ def __init__(
d_model,
d_ffn,
factor,
moving_avg_kernel_size,
moving_avg_window_size,
dropout,
activation="relu",
output_attention=False,
@@ -38,7 +38,7 @@ def __init__(

self.seq_len = n_steps
self.n_layers = n_layers
self.series_decomp = SeriesDecompositionBlock(moving_avg_kernel_size)
self.series_decomp = SeriesDecompositionBlock(moving_avg_window_size)
self.enc_embedding = DataEmbedding_wo_Pos(
n_features,
d_model,
@@ -54,7 +54,7 @@ def __init__(
),
d_model,
d_ffn,
moving_avg_kernel_size,
moving_avg_window_size,
dropout,
activation,
)
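For readers unfamiliar with the block receiving the renamed argument: SeriesDecompositionBlock follows the Autoformer paper's moving-average decomposition, splitting a series into trend and seasonal parts. A self-contained approximation of that idea, not the library's actual class, just a sketch of the technique with an illustrative window size:

import torch
import torch.nn as nn

class MovingAvgDecomposition(nn.Module):
    """Split x into (seasonal, trend) via a moving average of the given window size."""
    def __init__(self, window_size: int):
        super().__init__()
        self.window_size = window_size
        self.avg = nn.AvgPool1d(kernel_size=window_size, stride=1)

    def forward(self, x: torch.Tensor):
        # x: (batch, time, features); pad both ends so the trend keeps the sequence length
        front = x[:, :1, :].repeat(1, (self.window_size - 1) // 2, 1)
        end = x[:, -1:, :].repeat(1, self.window_size // 2, 1)
        padded = torch.cat([front, x, end], dim=1)
        trend = self.avg(padded.permute(0, 2, 1)).permute(0, 2, 1)
        seasonal = x - trend
        return seasonal, trend

# window_size=25 mirrors the renamed moving_avg_window_size argument; value is illustrative
seasonal, trend = MovingAvgDecomposition(25)(torch.randn(8, 48, 37))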
17 changes: 17 additions & 0 deletions pypots/imputation/patchtst/__init__.py
@@ -0,0 +1,17 @@
"""
The package of the partially-observed time-series imputation model PatchTST.
Refer to the paper "Nie, Y., Nguyen, N. H., Sinthong, P., & Kalagnanam, J. (2023).
A Time Series is Worth 64 Words: Long-term Forecasting with Transformers. ICLR 2023.".
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause


from .model import PatchTST

__all__ = [
"PatchTST",
]
24 changes: 24 additions & 0 deletions pypots/imputation/patchtst/data.py
@@ -0,0 +1,24 @@
"""
Dataset class for PatchTST.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Union

from ..saits.data import DatasetForSAITS


class DatasetForPatchTST(DatasetForSAITS):
"""Actually PatchTST uses the same data strategy as SAITS, needs MIT for training."""

def __init__(
self,
data: Union[dict, str],
return_X_ori: bool,
return_labels: bool,
file_type: str = "h5py",
rate: float = 0.2,
):
super().__init__(data, return_X_ori, return_labels, file_type, rate)
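Since DatasetForPatchTST only inherits SAITS's masked-imputation-task data strategy, constructing it mirrors the parent class. A usage sketch under the assumption that a dict with an "X" array of shape (n_samples, n_steps, n_features) follows the usual PyPOTS data convention; the toy array and its shape are illustrative:

import numpy as np
from pypots.imputation.patchtst.data import DatasetForPatchTST

toy_data = {"X": np.random.randn(16, 48, 37)}  # assumed toy input, not from the diff

train_set = DatasetForPatchTST(
    toy_data,
    return_X_ori=False,
    return_labels=False,
    file_type="h5py",
    rate=0.2,  # fraction of observed values artificially masked for MIT training
)
sample = train_set[0]  # indexed like any torch-style dataset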