-
Notifications
You must be signed in to change notification settings - Fork 692
[ENH] Kolmogorov Arnold Block for NBeats #1751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@fkiraly @benHeid can you kindly review/merge it so I integrate NBEATSX modification in NBEATS without conflicts as I have asked @julian-fong and he is not working on NBEATSX. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not fully reviewed yet. Will continue in the next days. But I share my current comments so that you already receive some feedbacks.
When training KANs, the grid can be iteratively be refined. I wonder, if there is a way to implement this also here. However, this might probably more difficult and require changes to the trainer. So probably out of scope for this PR. Do you have opinions on that?
@@ -0,0 +1,528 @@ | |||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The license of the original implementation is MIT. So in theory it is okay to copy the file. However, please add some credits at the top of the file.
Alternatively, we could think about adding KAN as a dependency.
@fkiraly do you have any additions on that matter?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I will add the appropriate credits at the top of the file. Additionally, I agree that adding KAN as a dependency—perhaps as a soft dependency—seems like a good idea, especially considering its increasing relevance in time series forecasting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fkiraly pinging you again to check if this is okay for you :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
which package are we exactly planning to add as a soft dep?
If it is a single layer, I think copying it over and including the license is perhaps better for now, because we do not have machinery to manage soft dependencies (like scikit-base
or similar).
The proposed design in here sktime/enhancement-proposals#39 would allow that, but right now I think this would require a significant amounts of custom code to handle.
Or, is there an easy way that I am not seeing how the soft dependency import would work for part of the NN?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
which package are we exactly planning to add as a soft dep?
It is pykan
library , reference https://pypi.org/project/pykan/
If it is a single layer, I think copying it over and including the license is perhaps better for now, because we do not have machinery to manage soft dependencies (like
scikit-base
or similar).The proposed design in here sktime/enhancement-proposals#39 would allow that, but right now I think this would require a significant amounts of custom code to handle.
Or, is there an easy way that I am not seeing how the soft dependency import would work for part of the NN?
yes it is a single layer, only used in NBEATS
. Also the library pykan
has much more, but we only need this. I have implemented what you already suggested above. I have copied it over and included the license.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, makes sense, as long as the license points to the original source and is provided in full form (assuming it hsa the usual reqiurement to reproduce)
Thanks! Will address these reviews soon. |
I'll explore this and share my thoughts. |
… while using KAN blocks in NBEATS.
@benHeid I have addressed the reviews. Kindly review the updated PR.
To address this, I have taken logic from original implementation of pykan library and made custom Callback i.e. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I only have one last comment. But I also would like to hear @fkiraly opinion on the license of the kan_layer
@benHeid I have addressed the reviews. Kindly review the updated PR, thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Could you kindly also add some explanation in the PR fist post about the changes to the structure, and what goes in the common base class? What is going on with the grid callback?
Also requesting review by @phoeenniixx and @PranavBhatP
Optionally, could you check in the layers
module whether we can use that to avoid duplicate layer specifications?
evaluate x on B-spline bases | ||
|
||
Args: | ||
----- |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please use Parameters
like in sktime
and follow the numpydoc style docstrings, No need to add colon (:)
): | ||
"""' | ||
Initialize a KANLayer | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move this docstring above just after the class definition?
shown to consistently outperform N-BEATS. | ||
|
||
Args: | ||
stack_types: One of the following values: “generic”, “seasonality" or |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use numpydoc style doctsring
forecasting <http://arxiv.org/abs/1905.10437>`_. The network has (if | ||
used as ensemble) outperformed all other methods including ensembles of | ||
traditional statical methods in the M4 competition. The M4 competition is | ||
arguably the most important benchmark for univariate time series forecasting. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be good to mention the paper for KAN
blocks here and how it works (just an overview) and how it is different from the original NBeats
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mentioned the reference with details.
…ferences, and extend NBEATSKAN test cases.
Updated the first post to include explanations on the structure changes, common base class, and grid callback as requested.
I checked it and currently blocks are reused from |
I think we can also move some of the layers (like Also, a question for @fkiraly: Imo, we could also move networks present in |
Yes, in the future, |
Yes this can also be added! But this is a |
What I was essentially thinking was that |
I agree, it is not used anywhere right now, but we may need it in future, then it will be handy to have it in As you said, maybe @benHeid, @fkiraly can also provide us a new perspective... |
I think ultimately all layers should move to |
Description
Fixes: #1741
This PR adds Kolmogorov Arnold(KAN) Blocks in NBeats and also does refactoring of NBeats. Implementation of KAN blocks' layers is taken from original paper code.
Changes in Structure
Introduced the
NBEATSKAN
module, which enables usage of KAN blocks within theNBEATS
architecture.Integrated
KANLayer
logic, implemented inkan_layer.py
, which handles KAN-specific operations such as:Imported
KANLayer
tosubmodules.py
for block operations, allowingNBEATSKAN
to delegate block-level behavior throughuse_kan=True
.Added the
NBEATSAdapter
class to encapsulate common methods shared by bothNBEATS
andNBEATSKAN
, including:__init__
), which is separately defined in each class to maintain architectural flexibility.GridUpdateCallback
When training KAN-based models, the grid can be iteratively refined during training for better performance.
To support this, logic from the original
[pykan](https://github.com/KindXiaoming/pykan)
implementation has been adapted to define a custom callback namedGridUpdateCallback
. This callback automatically updates the grid at specified training steps, improving model accuracy and convergence.This callback has been tested successfully and demonstrates improved results in practice.
An example usage is provided in:
examples/nbeats_with_kan.py
Checklist
pre-commit install
.To run hooks independent of commit, execute
pre-commit run --all-files
Make sure to have fun coding!