1 change: 1 addition & 0 deletions docs/sbi.rst
@@ -46,6 +46,7 @@ Embedding nets
sbi.neural_nets.embedding_nets.PermutationInvariantEmbedding
sbi.neural_nets.embedding_nets.ResNetEmbedding1D
sbi.neural_nets.embedding_nets.ResNetEmbedding2D
sbi.neural_nets.embedding_nets.TransformerEmbedding


Training
300 changes: 224 additions & 76 deletions sbi/neural_nets/embedding_nets/transformer.py
@@ -628,91 +628,234 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:


class TransformerEmbedding(nn.Module):
Contributor comment:
Would you be up for adding a short explanatory class docstring here?

e.g., for an SBI user working with time series or images but not so familiar with transformers, give a concise overview of how they can use this class? e.g., what "vit" means (for images), what "is_causal" means (for time series), etc. Not a tutorial, just a brief high-level explanation. Maybe even with a short code Example block.

When we add this docstring at the top class level, it will show up nicely in the Sphinx documentation, e.g., like the EnsemblePosterior here: https://sbi.readthedocs.io/en/latest/reference/_autosummary/sbi.inference.EnsemblePosterior.html#sbi.inference.EnsemblePosterior

def __init__(self, config):
r"""
Transformer-based embedding network for **time series** and **image** data.

This module provides a flexible embedding architecture that supports both
(1) 1D / multivariate time series (e.g., experimental trials, temporal signals),
and
(2) image inputs via a lightweight Vision Transformer (ViT)-style patch embedding.

It is designed for simulation-based inference (SBI) workflows where raw
observations must be encoded into fixed-dimensional embeddings before passing
them to a neural density estimator.
Contributor comment on lines +632 to +641: 👍


Parameters
----------
pos_emb :
Positional embedding type. One of ``{"rotary", "positional", "none"}``.
pos_emb_base :
Base frequency for rotary positional embeddings.
rms_norm_eps :
Epsilon for RMSNorm layers.
router_jitter_noise :
Noise added when routing tokens to MoE experts.
vit_dropout :
Dropout applied inside ViT patch embedding layers.
mlp_activation :
Activation used inside the feedforward blocks.
is_causal :
If ``True``, applies a causal mask during attention (useful for time-series).
vit :
If ``True``, enables Vision Transformer mode for 2D image inputs.
num_hidden_layers :
Number of transformer encoder blocks.
num_attention_heads :
Number of self-attention heads.
num_key_value_heads :
Number of key/value heads (grouped-query attention when smaller than ``num_attention_heads``).
intermediate_size :
Hidden dimension of feedforward network (or MoE experts).
ffn :
Feedforward type. One of ``{"mlp", "moe"}``.
head_dim :
Per-head embedding dimension. If ``None``, inferred as
``feature_space_dim // num_attention_heads``.
attention_dropout :
Dropout used inside the attention mechanism.
feature_space_dim :
Dimensionality of the token embeddings flowing through the transformer.
- For time-series, this is the model dimension.
- For images (``vit=True``), this is the post-patch-projection embedding size.
final_emb_dimension :
Output embedding dimension. Defaults to ``feature_space_dim // 2``.
image_size :
Input image height/width (only if ``vit=True``).
patch_size :
ViT patch size (only if ``vit=True``).
num_channels :
Number of image channels for ViT mode.
num_local_experts :
Number of MoE experts (only relevant when ``ffn="moe"``).
num_experts_per_tok :
How many experts each token is routed to in MoE mode.

Notes
-----
**Time-series mode (``vit=False``)**
- Inputs of shape ``(batch, seq_len)`` (scalar series) are automatically
projected to ``(batch, seq_len, feature_space_dim)``.
- Inputs of shape ``(batch, seq_len, features)`` are used as-is.
- Causal masking is applied if ``is_causal=True`` (default).
- Suitable for experimental trials, temporal dynamics, or sets of sequential
observations.

**Image mode (``vit=True``)**
- Inputs must have shape ``(batch, channels, height, width)``.
- Images are patchified, linearly projected, and fed to the transformer.
- Causal masking is disabled in this mode.

**Output**
The embedding is obtained by selecting the final token and applying a linear
projection, resulting in a tensor of shape:

``(batch, final_emb_dimension)``

Contributor comment: 👍

Example
-------
**1D time-series (default mode)**::

from sbi.neural_nets.embedding_nets import TransformerEmbedding
import torch

x = torch.randn(16, 100) # (batch, seq_len)
emb = TransformerEmbedding(feature_space_dim=64)
z = emb(x)

**Image input (ViT-style)**::

from sbi.neural_nets.embedding_nets import TransformerEmbedding
import torch

x = torch.randn(8, 3, 64, 64) # (batch, C, H, W)
emb = TransformerEmbedding(
vit=True,
image_size=64,
patch_size=8,
num_channels=3,
feature_space_dim=128,
)
z = emb(x)
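
**Multivariate time series with a mixture-of-experts feedforward**
(an illustrative sketch; the expert counts and head settings below are
assumptions, not defaults)::

    from sbi.neural_nets.embedding_nets import TransformerEmbedding
    import torch

    x = torch.randn(16, 100, 32)  # (batch, seq_len, features), used as-is
    emb = TransformerEmbedding(
        feature_space_dim=32,  # must match the trailing feature dimension
        num_attention_heads=4,
        num_key_value_heads=4,
        ffn="moe",
        num_local_experts=4,
        num_experts_per_tok=2,
    )
    z = emb(x)  # -> (16, final_emb_dimension)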
"""

def __init__(
self,
*,
pos_emb: str = "rotary",
pos_emb_base: float = 10e4,
rms_norm_eps: float = 1e-05,
router_jitter_noise: float = 0.0,
vit_dropout: float = 0.5,
mlp_activation: str = "gelu",
is_causal: bool = True,
vit: bool = False,
num_hidden_layers: int = 4,
num_attention_heads: int = 12,
num_key_value_heads: int = 12,
intermediate_size: int = 256,
ffn: str = "mlp",
head_dim: Optional[int] = None,
attention_dropout: float = 0.5,
feature_space_dim: int,
final_emb_dimension: Optional[int] = None,
image_size: Optional[int] = None,
patch_size: Optional[int] = None,
num_channels: Optional[int] = None,
num_local_experts: Optional[int] = None,
num_experts_per_tok: Optional[int] = None,
):
super().__init__()
"""
Main class for constructing a transformer embedding
Basic configuration parameters:
pos_emb (string): position encoding to be used, currently available:
{"rotary", "positional", "none"}
pos_emb_base (float): base used to construct the positional encoding
rms_norm_eps (float): epsilon added to the rms variance computation
ffn (string): feedforward layer used after computing the attention:
{"mlp", "moe"}
mlp_activation (string): activation function to be used within the ffn
layer
is_causal (bool): specifies whether a causal mask should be created
vit (bool): specifies whether a convolutional layer should be used for
processing images, inspired by the vision transformer
num_hidden_layers (int): number of transformer blocks
num_attention_heads (int): number of attention heads
num_key_value_heads (int): number of key/value heads
feature_space_dim (int): dimension of the feature vectors
intermediate_size (int): hidden size of the feedforward layer
head_dim (int): dimension of the key/query vectors
attention_dropout (float): value for the dropout of the attention layer

Args:
pos_emb: position encoding to be used, currently available:
{"rotary", "positional", "none"}
pos_emb_base: base used to construct the positional encoding
rms_norm_eps: epsilon added to the rms variance computation
ffn: feedforward layer used after computing the attention:
{"mlp", "moe"}
mlp_activation: activation function to be used within the ffn
layer
is_causal: specifies whether a causal mask should be created
vit: specifies whether a convolutional layer should be used for
processing images, inspired by the vision transformer
num_hidden_layers: number of transformer blocks
num_attention_heads: number of attention heads
num_key_value_heads: number of key/value heads
feature_space_dim: dimension of the feature vectors
intermediate_size: hidden size of the feedforward layer
head_dim: dimension of the key/query vectors
attention_dropout: value for the dropout of the attention layer

MoE:
router_jitter_noise (float): noise added before routing the input vectors
to the experts
num_local_experts (int): total number of experts
num_experts_per_tok (int): number of experts each token is assigned to
router_jitter_noise: noise added before routing the input vectors
to the experts
num_local_experts: total number of experts
num_experts_per_tok: number of experts each token is assigned to

ViT
feature_space_dim (int): dimension of the feature vectors after
preprocessing the images
image_size (int): dimension of the square image used to create
the positional encoders;
a rectangular image can be used at training/inference time by
resampling the encoders
patch_size (int): size of the square patches used to create the
positional encoders
num_channels (int): number of channels of the input image
vit_dropout (float): value for the dropout applied in the ViT patch embedding layers
"""
self.config = {
"pos_emb": "rotary",
"pos_emb_base": 10e4,
"rms_norm_eps": 1e-05,
"router_jitter_noise": 0.0,
"vit_dropout": 0.5,
"mlp_activation": "gelu",
"is_causal": True,
"vit": False,
"num_hidden_layers": 4,
"num_attention_heads": 12,
"num_key_value_heads": 12,
"intermediate_size": 256,
"ffn": "mlp",
"head_dim": None,
"attention_dropout": 0.5,
}
feature_space_dim: dimension of the feature vectors after
preprocessing the images
image_size: dimension of the square image used to create
the positional encoders;
a rectangular image can be used at training/inference time by
resampling the encoders
patch_size: size of the square patches used to create the
positional encoders
num_channels: number of channels of the input image
vit_dropout: value for the dropout applied in the ViT patch embedding layers
"""

self.config = dict(
pos_emb=pos_emb,
pos_emb_base=pos_emb_base,
rms_norm_eps=rms_norm_eps,
router_jitter_noise=router_jitter_noise,
vit_dropout=vit_dropout,
mlp_activation=mlp_activation,
is_causal=is_causal,
vit=vit,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
intermediate_size=intermediate_size,
ffn=ffn,
head_dim=head_dim,
attention_dropout=attention_dropout,
feature_space_dim=feature_space_dim,
image_size=image_size,
patch_size=patch_size,
num_channels=num_channels,
num_local_experts=num_local_experts,
num_experts_per_tok=num_experts_per_tok,
)

self.config.update(config)
self.preprocess = ViTEmbeddings(self.config) if vit else IdentityEncoder()

self.preprocess = (
ViTEmbeddings(self.config) if self.config["vit"] else IdentityEncoder()
)
self._supports_scalar_series = not vit
if self._supports_scalar_series:
self.scalar_projection = nn.Linear(
1, feature_space_dim
) # proj 1D → model dim

self.layers = nn.ModuleList([
TransformerBlock(self.config)
for _ in range(self.config["num_hidden_layers"])
TransformerBlock(self.config) for _ in range(num_hidden_layers)
])
self.is_causal = self.config["is_causal"] and not self.config["vit"]
self.is_causal = is_causal and not vit

self.norm = RMSNorm(
self.config["feature_space_dim"], eps=self.config["rms_norm_eps"]
)
final_emb_dimension = self.config.get(
"final_emb_dimension", self.config["feature_space_dim"] // 2
)
if not config["vit"] and final_emb_dimension > self.config["feature_space_dim"]:
self.norm = RMSNorm(feature_space_dim, eps=rms_norm_eps)

if final_emb_dimension is None:
final_emb_dimension = feature_space_dim // 2

if not vit and final_emb_dimension > feature_space_dim:
raise ValueError(
"The final embedding dimension should be equal or smaller than "
"the input dimension"
"The final embedding dimension should be "
"equal or smaller than the input dimension"
)
self.aggregator = nn.Linear(
self.config["feature_space_dim"],
feature_space_dim,
final_emb_dimension,
)
self.causal_mask_cache_ = (None, None, None)
@@ -764,21 +907,26 @@ def forward(
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
"""
Args:
input (`torch.Tensor`): input of shape `(batch, seq_len,
feature_space_dim)`
or `(batch, num_channels, height, width)` if using ViT
attention_mask (`torch.Tensor`, *optional*):
input:
input of shape `(batch, seq_len, feature_space_dim)`
or `(batch, num_channels, height, width)` if using ViT
attention_mask:
attention mask of size `(batch_size, sequence_length)`
output_attentions (`bool`, *optional*):
output_attentions:
Whether or not to return the attention tensors
cache_attention_mask (`bool`, *optional*):
cache_attention_mask:
Whether or not to cache the expanded attention mask, useful if using
multiple batches with identical input shapes
kwargs (`dict`, *optional*):
kwargs:
Arbitrary kwargs
"""

input = self.preprocess(input)

if self._supports_scalar_series and input.ndim == 2:
input = input.unsqueeze(-1) # (B, T, 1)
input = self.scalar_projection(input) # (B, T, feature_space_dim)

if self.is_causal:
dtype, device = input.dtype, input.device

28 changes: 26 additions & 2 deletions tests/embedding_net_test.py
@@ -260,7 +260,7 @@ def repeat_to_match_shape(x, input_shape):
)
@pytest.mark.parametrize("seq_length", (24, 13, 5))
def test_transformer_embedding(config, seq_length):
net = TransformerEmbedding(config=config)
net = TransformerEmbedding(**config)

def simulator(theta):
x = MultivariateNormal(
@@ -291,7 +291,7 @@ def simulator(theta):
)
@pytest.mark.parametrize("img_shape", ((3, 32, 24), (3, 64, 64)))
def test_transformer_vitembedding(config, img_shape):
net = TransformerEmbedding(config=config)
net = TransformerEmbedding(**config)

def simulator(theta):
x = MultivariateNormal(
@@ -311,6 +311,30 @@ def simulator(theta):
_test_helper_embedding_net(prior, xo, simulator, net)


@pytest.mark.parametrize("seq_length", (10, 20))
def test_transformer_embedding_scalar_timeseries(seq_length):
net = TransformerEmbedding(
pos_emb="rotary",
feature_space_dim=32,
num_attention_heads=4,
num_key_value_heads=4,
vit=False,
head_dim=None,
intermediate_size=64,
num_hidden_layers=2,
attention_dropout=0.1,
)

def simulator(theta):
batch = theta.shape[0]
return torch.randn(batch, seq_length) + theta[:, 0:1]

xo = torch.randn(1, seq_length) # shape: (1, T)
prior = MultivariateNormal(torch.zeros(1), torch.eye(1))

_test_helper_embedding_net(prior, xo, simulator, net)
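

# Hypothetical companion sketch (illustrative only; parameter values are
# assumptions): per the class docstring, a multivariate series of shape
# (batch, seq_len, features) is used as-is when features matches
# feature_space_dim.
@pytest.mark.parametrize("seq_length", (10, 20))
def test_transformer_embedding_multivariate_timeseries(seq_length):
    net = TransformerEmbedding(
        feature_space_dim=32,
        num_attention_heads=4,
        num_key_value_heads=4,
        intermediate_size=64,
        num_hidden_layers=2,
        attention_dropout=0.1,
    )

    def simulator(theta):
        batch = theta.shape[0]
        return torch.randn(batch, seq_length, 32) + theta[:, 0:1, None]

    xo = torch.randn(1, seq_length, 32)  # shape: (1, T, F)
    prior = MultivariateNormal(torch.zeros(1), torch.eye(1))

    _test_helper_embedding_net(prior, xo, simulator, net)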


def _test_helper_embedding_net(prior, xo, simulator, net):
estimator_provider = posterior_nn(
"mdn",