Skip to content

[ENH] Added TCN forecaster in aeon/forecasting/deep_learning #2938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
af94a86
Basedeep forecaster added
lucifer4073 May 22, 2025
7bb161e
Merge upstream main to basedlf
lucifer4073 May 24, 2025
d2ee9ec
init for basedlf added
lucifer4073 May 26, 2025
ab3030c
test file and axis added for basedeepforecaster
lucifer4073 Jun 15, 2025
1f202db
test locally
lucifer4073 Jun 15, 2025
14eb41f
dlf corrected
lucifer4073 Jun 15, 2025
d1a2aab
tf soft dep added
lucifer4073 Jun 22, 2025
865ed14
Merge remote-tracking branch 'upstream/main' into basedlf
lucifer4073 Jun 22, 2025
5fb72c7
tcn network added
lucifer4073 Jul 6, 2025
3434757
tcn_net pytest added
lucifer4073 Jul 6, 2025
a73c5f7
Merge branch 'main' of https://github.com/aeon-toolkit/aeon into tcn_net
lucifer4073 Jul 6, 2025
f2f393d
Merge branch 'tcn_net' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 6, 2025
c2b6231
Merge branch 'basedlf' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 6, 2025
c602e39
tcn_network updated with default params
lucifer4073 Jul 6, 2025
ad2fc01
Merge branch 'tcn_net' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 6, 2025
05a0f35
TCN forecaster added
lucifer4073 Jul 7, 2025
2f3c98b
tcn reshaped
lucifer4073 Jul 7, 2025
dd5b014
Merge branch 'main' of https://github.com/aeon-toolkit/aeon into tcn_fst
lucifer4073 Jul 7, 2025
e630a99
Merge branch 'tcn_net' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 7, 2025
f6447b1
tcn changed
lucifer4073 Jul 8, 2025
30d862a
base fst changed
lucifer4073 Jul 8, 2025
135a98d
Merge branch 'tcn_net' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 8, 2025
9b9d266
TCN forecaster updated
lucifer4073 Jul 8, 2025
78b2f3d
test file corrected
lucifer4073 Jul 8, 2025
79fe3e2
Merge branch 'basedlf' of https://github.com/lucifer4073/aeon into tc…
lucifer4073 Jul 8, 2025
49be666
tcn updated
lucifer4073 Jul 8, 2025
7bacdac
tcn updated
lucifer4073 Jul 8, 2025
9a1b878
tcnfst updated with net
lucifer4073 Jul 8, 2025
08dadec
doctest corrected
lucifer4073 Jul 8, 2025
b167479
merge tcn_net
lucifer4073 Jul 8, 2025
086c5a4
changes made
lucifer4073 Jul 13, 2025
a1f68cd
Merge branch 'main' into tcn_fst
lucifer4073 Jul 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: PR pytest
on:
push:
branches:
- main
- tcn_fst
pull_request:
paths:
- "aeon/**"
Expand Down
9 changes: 9 additions & 0 deletions aeon/forecasting/deep_learning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Initialization for aeon forecasting deep learning module."""

__all__ = [
"BaseDeepForecaster",
"TCNForecaster",
]

from aeon.forecasting.deep_learning._tcn import TCNForecaster
from aeon.forecasting.deep_learning.base import BaseDeepForecaster
150 changes: 150 additions & 0 deletions aeon/forecasting/deep_learning/_tcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""TCNForecaster module for deep learning forecasting in aeon."""

from __future__ import annotations

__maintainer__ = []

__all__ = ["TCNForecaster"]

from typing import Any

from aeon.forecasting.deep_learning.base import BaseDeepForecaster
from aeon.networks._tcn import TCNNetwork


class TCNForecaster(BaseDeepForecaster):
"""A deep learning forecaster using Temporal Convolutional Network (TCN).

It leverages the `TCNNetwork` from aeon's network module
to build the architecture suitable for forecasting tasks.

Parameters
----------
horizon : int, default=1
Forecasting horizon, the number of steps ahead to predict.
window : int, default=10
The window size for creating input sequences.
batch_size : int, default=32
Batch size for training the model.
epochs : int, default=100
Number of epochs to train the model.
verbose : int, default=0
Verbosity mode (0, 1, or 2).
optimizer : str or tf.keras.optimizers.Optimizer, default='adam'
Optimizer to use for training.
loss : str or tf.keras.losses.Loss, default='mse'
Loss function for training.
random_state : int, default=None
Seed for random number generators.
axis : int, default=0
Axis along which to apply the forecaster.
n_blocks : list of int, default=[16, 16, 16]
List specifying the number of output channels for each layer of the
TCN. The length determines the depth of the network.
kernel_size : int, default=2
Size of the convolutional kernel in the TCN.
dropout : float, default=0.2
Dropout rate applied after each convolutional layer for
regularization.
"""

_tags = {
"python_dependencies": ["tensorflow"],
"capability:horizon": True,
"capability:multivariate": True,
"capability:exogenous": False,
"capability:univariate": True,
}

def __init__(
self,
horizon=1,
window=10,
batch_size=32,
epochs=100,
verbose=0,
optimizer="adam",
loss="mse",
random_state=None,
axis=0,
n_blocks=None,
kernel_size=2,
dropout=0.2,
):
super().__init__(
horizon=horizon,
window=window,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
optimizer=optimizer,
random_state=random_state,
axis=axis,
loss=loss,
)
self.n_blocks = n_blocks
self.kernel_size = kernel_size
self.dropout = dropout

def _build_model(self, input_shape):
"""Build the TCN model for forecasting.

Parameters
----------
input_shape : tuple
Shape of input data, typically (window, num_inputs).

Returns
-------
model : tf.keras.Model
Compiled Keras model with TCN architecture.
"""
import tensorflow as tf

# Initialize the TCN network with the updated parameters
network = TCNNetwork(
n_blocks=self.n_blocks if self.n_blocks is not None else [16, 16, 16],
kernel_size=self.kernel_size,
dropout=self.dropout,
)

# Build the network with the given input shape
input_layer, output = network.build_network(input_shape=input_shape)

# Create the final model
model = tf.keras.Model(inputs=input_layer, outputs=output)
return model

# Added to handle __name__ in tests (class-level access)
@classmethod
def _get_test_params(
cls, parameter_set: str = "default"
) -> dict[str, Any] | list[dict[str, Any]]:
"""
Return testing parameter settings for the estimator.

Parameters
----------
parameter_set : str, default="default"
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.
For forecasters, a "default" set of parameters should be provided for
general testing, and a "results_comparison" set for comparing against
previously recorded results if the general set does not produce suitable
probabilities to compare against.

Returns
-------
params : dict or list of dict, default={}
Parameters to create testing instances of the class.
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
"""
param = {
"epochs": 10,
"batch_size": 4,
"n_blocks": [8, 8],
"kernel_size": 2,
"dropout": 0.1,
}
return [param]
Loading
Loading