
Conversation

@satwiksps
Contributor

Support scalar time-series and simplify config handling in TransformerEmbedding

This PR improves the usability of TransformerEmbedding by adding native support for scalar (1D) time-series inputs and removing the requirement to manually specify a full configuration dictionary.

Problem

TransformerEmbedding currently only accepts multivariate inputs of shape (batch, seq_len, D) and requires the user to assemble a full configuration dictionary by hand, which makes it harder to use for plain scalar time-series than the other embedding nets.

What this PR does

  • Adds automatic projection from scalar inputs to the required feature_space_dim via a lazily-initialized input_proj layer.
  • Extends the forward pass to handle the following input shapes (see the usage sketch after this list):
    • (batch, seq_len)
    • (batch, seq_len, 1)
    • (batch, seq_len, D) (existing behavior)
  • Ensures compatibility with the attention mechanism (e.g., head dimension constraints).
  • Adds new test test_transformer_embedding_scalar_timeseries to verify:
    • Correct handling of 1D inputs
    • Proper projection behavior
    • Successful integration into embedding API
  • Updates docstrings to reflect new behavior.
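
A minimal usage sketch of the new behavior, assuming TransformerEmbedding can be imported from sbi.neural_nets.embedding_nets and constructed with default settings (the exact constructor signature is still being discussed in this PR, so treat the arguments as placeholders):

    import torch
    from sbi.neural_nets.embedding_nets import TransformerEmbedding  # assumed import path

    embedding = TransformerEmbedding()  # defaults assumed; actual arguments may differ

    x_scalar = torch.randn(32, 100)      # (batch, seq_len): projected to feature_space_dim via the lazily-initialized input_proj
    x_channel = torch.randn(32, 100, 1)  # (batch, seq_len, 1): treated like the scalar case
    x_multi = torch.randn(32, 100, 4)    # (batch, seq_len, D): existing behavior, passed through unchanged

    for x in (x_scalar, x_channel, x_multi):
        print(embedding(x).shape)  # one embedding vector per batch element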

Why this is valuable

  • Makes the transformer embedding consistent with other embedding nets (RNN, CNN, LRU) that already accept scalar time-series.
  • Eliminates user friction by making TransformerEmbedding usable out of the box.
  • Enables simplified tutorials, workflows, and example code for time-series inference.

Testing

  • New test added in embedding_net_test.py.
  • All transformer-specific tests pass.
  • Pre-commit (ruff, formatting, linting) passed successfully.

Checklist

  • Code changes follow project style.
  • Added targeted tests.
  • Updated docstrings.
  • Pre-commit hooks passed.
  • No breaking changes introduced.

Closes #1696

@codecov

codecov bot commented Nov 15, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (main@2c216c2). Learn more about missing BASE report.
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1703   +/-   ##
=======================================
  Coverage        ?   84.67%           
=======================================
  Files           ?      137           
  Lines           ?    11493           
  Branches        ?        0           
=======================================
  Hits            ?     9732           
  Misses          ?     1761           
  Partials        ?        0           
Flag Coverage Δ
unittests 84.67% <100.00%> (?)

Flags with carried forward coverage won't be shown.

Files with missing lines Coverage Δ
sbi/neural_nets/embedding_nets/transformer.py 93.40% <100.00%> (ø)

@janfb (Contributor) left a comment

Thanks @satwiksps, good first draft!

I suggest making all the config options actual arguments to __init__ to give the user full control and to have explicit type hints and defaults.

We could even think about using a dataclass TransformerConfig, but this would create additional overhead for the user, who would have to instantiate the config object externally; not sure. What's your take here?
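
To illustrate the two options being discussed, here is a rough sketch; all parameter names, types, and defaults below are hypothetical placeholders, not the actual API:

    from dataclasses import dataclass

    import torch.nn as nn

    # Option A: explicit keyword arguments with type hints and defaults (sketch only).
    class TransformerEmbedding(nn.Module):
        def __init__(
            self,
            feature_space_dim: int = 64,   # hypothetical default
            num_heads: int = 4,            # hypothetical default
            num_layers: int = 2,           # hypothetical default
            pos_emb: str = "sinusoidal",   # hypothetical default
            is_causal: bool = False,
        ) -> None:
            super().__init__()
            ...

    # Option B: a dataclass config that the user instantiates and passes in (sketch only).
    @dataclass
    class TransformerConfig:
        feature_space_dim: int = 64
        num_heads: int = 4
        num_layers: int = 2
        pos_emb: str = "sinusoidal"
        is_causal: bool = False

    # Under Option B the user would write: TransformerEmbedding(TransformerConfig(num_heads=8))

Option A keeps the call site simple and gives explicit signatures; Option B groups related settings but adds one more object for the user to create.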

@janfb (Contributor) left a comment

Looks good!

Added more suggestions for making it look even better in terms of documentation. Would be great if you could add these as well 🙏

@janfb commented on this excerpt:

    super().__init__()
    """
    Main class for constructing a transformer embedding
    Basic configuration parameters:

Suggested change:
-   Basic configuration parameters:
+   Args:

@janfb commented on this diff excerpt:

    Main class for constructing a transformer embedding
    Basic configuration parameters:
-   pos_emb (string): position encoding to be used, currently available:
+   pos_emb: position encoding to be used, currently available:

If arg docstrings break the line, please add a small indent to set them apart from other arguments.
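
For example, the continuation line could be indented one extra level so it reads as part of the same argument (illustrative only; the option list is elided):

    Args:
        pos_emb: position encoding to be used, currently available:
            ...
        is_causal: specifies whether causal mask should be created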

@janfb commented on this diff excerpt:

-   is_causal (bool): specifies whether causal mask should be created
-   vit (bool): specifies the whether a convolutional layer should be used for
+   is_causal: specifies whether causal mask should be created
+   vit: specifies the whether a convolutional layer should be used for

Please add an indent in the line below.

@janfb commented on this diff excerpt:

-   ffn (string): feedforward layer after used after computing the attention:
    pos_emb_base: base used to construct the positinal encoding
    rms_norm_eps: noise added to the rms variance computation
+   ffn: feedforward layer after used after computing the attention:

Please add an indent in the line below.

@janfb commented on this diff excerpt:

    ffn: feedforward layer after used after computing the attention:
    {"mlp", "moe"}
-   mlp_activation (string): activation function to be used within the ffn
+   mlp_activation: activation function to be used within the ffn

Please add an indent in the line below.

@janfb commented on this excerpt:

    return embeddings


    class TransformerEmbedding(nn.Module):

Would you be up for adding a short explanatory class docstring here?

E.g., for an SBI user working with time series or images but not so familiar with transformers, give a concise overview of how they can use this class: what "vit" means (for images), what "is_causal" means (for time series), etc. Not a tutorial, just a brief high-level explanation, maybe even with a short code Example block.

When we add this docstring at the top class level, it will show up nicely in the Sphinx documentation, e.g., like the EnsemblePosterior here: https://sbi.readthedocs.io/en/latest/reference/_autosummary/sbi.inference.EnsemblePosterior.html#sbi.inference.EnsemblePosterior
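
A rough sketch of what such a class-level docstring could look like (wording and arguments are illustrative, not final):

    class TransformerEmbedding(nn.Module):
        """Transformer-based embedding net for time-series and image data.

        Set ``is_causal=True`` for time series, so that each position only
        attends to earlier time steps; set ``vit=True`` to use a convolutional
        (ViT-style) patch embedding for image inputs.

        Example:
            >>> import torch
            >>> embedding = TransformerEmbedding()  # arguments illustrative
            >>> x = torch.randn(32, 100)            # scalar time series
            >>> emb = embedding(x)
        """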

@satwiksps force-pushed the feat/transformerembedding-1d-support branch 2 times, most recently from 20f7d47 to 8da9ec0 on November 20, 2025 at 20:18.
@satwiksps force-pushed the feat/transformerembedding-1d-support branch from 8da9ec0 to 4fb79ea on November 20, 2025 at 20:24.


Development

Successfully merging this pull request may close these issues.

Support scalar time-series and simplify config handling in TransformerEmbedding

2 participants