Add 1D time-series support and automatic input projection to TransformerEmbedding #1703
Conversation
…dd corresponding tests
Codecov Report
✅ All modified and coverable lines are covered by tests.
@@ Coverage Diff @@
## main #1703 +/- ##
=======================================
Coverage ? 84.67%
=======================================
Files ? 137
Lines ? 11493
Branches ? 0
=======================================
Hits ? 9732
Misses ? 1761
Partials ? 0
Flags with carried forward coverage won't be shown.
janfb left a comment
Thanks @satwiksps, good first draft!
I suggest making all the config options actual arguments to __init__, with explicit type hints and defaults, to give the user full control.
We could even think about using a dataclass TransformerConfig, but that would create additional overhead for the user, who would have to instantiate the dataclass config object externally; not sure. What's your take here?
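A minimal sketch of the two options being discussed; all argument names and default values below are illustrative placeholders, not the actual sbi API:

```python
from dataclasses import dataclass

import torch.nn as nn


class TransformerEmbedding(nn.Module):
    # Sketch only: explicit keyword arguments with type hints and defaults,
    # as suggested in the review. Names and defaults are illustrative.
    def __init__(
        self,
        feature_space_dim: int = 64,
        num_heads: int = 4,
        pos_emb: str = "sinusoidal",
        is_causal: bool = False,
        vit: bool = False,
    ) -> None:
        super().__init__()
        self.feature_space_dim = feature_space_dim
        self.is_causal = is_causal


# Alternative mentioned in the review: bundle the options in a dataclass,
# at the cost of the user instantiating the config object externally.
@dataclass
class TransformerConfig:
    feature_space_dim: int = 64
    num_heads: int = 4
    pos_emb: str = "sinusoidal"
    is_causal: bool = False
    vit: bool = False
```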
janfb left a comment
Looks good!
Added more suggestions for making it look even better in terms of documentation. Would be great if you could add these as well 🙏
     super().__init__()
     """
     Main class for constructing a transformer embedding
     Basic configuration parameters:

Suggested change:
-    Basic configuration parameters:
+    Args:
     Main class for constructing a transformer embedding
     Basic configuration parameters:
-    pos_emb (string): position encoding to be used, currently available:
+    pos_emb: position encoding to be used, currently available:

If arg docstrings break the line, please add a small indent to set them apart from the other arguments.
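For reference, the requested convention looks roughly like this (a generic illustration, not the actual sbi docstring; the continuation line gets one extra level of indentation):

```python
def example(pos_emb: str = "default") -> None:
    """Illustrative only, not the actual sbi docstring.

    Args:
        pos_emb: position encoding to be used, currently available:
            {"option_a", "option_b"}
    """
```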
-    is_causal (bool): specifies whether causal mask should be created
-    vit (bool): specifies the whether a convolutional layer should be used for
+    is_causal: specifies whether causal mask should be created
+    vit: specifies the whether a convolutional layer should be used for

Please add an indent in the line below.
-    ffn (string): feedforward layer after used after computing the attention:
     pos_emb_base: base used to construct the positinal encoding
     rms_norm_eps: noise added to the rms variance computation
+    ffn: feedforward layer after used after computing the attention:

Please add an indent in the line below.
     ffn: feedforward layer after used after computing the attention:
     {"mlp", "moe"}
-    mlp_activation (string): activation function to be used within the ffn
+    mlp_activation: activation function to be used within the ffn

Please add an indent in the line below.
     return embeddings

     class TransformerEmbedding(nn.Module):

Would you be up for adding a short explanatory class docstring here?
E.g., for an SBI user working with time series or images but not so familiar with transformers, give a concise overview of how they can use this class: what "vit" means (for images), what "is_causal" means (for time series), etc. Not a tutorial, just a brief high-level explanation, maybe even with a short code Example block.
When we add this docstring at the top class level, it will show up nicely in the Sphinx documentation, e.g., like the EnsemblePosterior here: https://sbi.readthedocs.io/en/latest/reference/_autosummary/sbi.inference.EnsemblePosterior.html#sbi.inference.EnsemblePosterior
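A rough sketch of the kind of docstring being requested; the wording and the constructor arguments in the Example are placeholders, not the final API:

```python
import torch
import torch.nn as nn


class TransformerEmbedding(nn.Module):
    """Transformer-based embedding network for time-series or image data.

    Sketch of the requested docstring, not final wording. ``vit=True`` uses a
    convolutional (ViT-style) layer to embed image inputs; ``is_causal=True``
    applies a causal attention mask so each position only attends to earlier
    ones, which suits time-series data.

    Example:
        >>> net = TransformerEmbedding(feature_space_dim=32, is_causal=True)
        >>> x = torch.randn(10, 100)  # (batch, seq_len) scalar time series
        >>> embedding = net(x)
    """
```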
Force-pushed from 20f7d47 to 8da9ec0.
Force-pushed from 8da9ec0 to 4fb79ea.
Support scalar time-series and simplify config handling in TransformerEmbedding

This PR improves the usability of TransformerEmbedding by adding native support for scalar (1D) time-series inputs and removing the requirement to manually specify a full configuration dictionary.

Problem
- TransformerEmbedding did not support scalar time-series shaped (batch, seq_len) or (batch, seq_len, 1).
- A full config dictionary had to be specified manually, which was unintuitive and prone to errors.

What this PR does
- Projects scalar inputs to feature_space_dim via a lazily-initialized input_proj layer.
- Accepts inputs shaped (batch, seq_len), (batch, seq_len, 1), and (batch, seq_len, D) (existing behavior); see the usage sketch below.
- Adds test_transformer_embedding_scalar_timeseries to verify the new input handling.
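A usage sketch of the behavior described above; the import path and constructor arguments are assumptions for illustration, not the confirmed API:

```python
import torch

# Import path and constructor arguments are assumed; the actual signature may differ.
from sbi.neural_nets.embedding_nets import TransformerEmbedding

embedding_net = TransformerEmbedding(feature_space_dim=32)

x_scalar = torch.randn(8, 100)      # (batch, seq_len)
x_single = torch.randn(8, 100, 1)   # (batch, seq_len, 1)
x_multi = torch.randn(8, 100, 5)    # (batch, seq_len, D), existing behavior

# Per the PR description, scalar inputs are projected to feature_space_dim by a
# lazily-initialized input_proj layer on the first forward pass.
for x in (x_scalar, x_single, x_multi):
    out = embedding_net(x)
```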
Why this is valuable
- Makes TransformerEmbedding usable out of the box for scalar time-series.

Testing
- New test added in embedding_net_test.py.
- ruff, formatting, and linting passed successfully.
Closes #1696