
Commit 8da9ec0

fix docstring indentation and add class level docstring
1 parent 41b2652 commit 8da9ec0

File tree

2 files changed (+126 −14 lines)


docs/sbi.rst

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ Embedding nets
     sbi.neural_nets.embedding_nets.PermutationInvariantEmbedding
     sbi.neural_nets.embedding_nets.ResNetEmbedding1D
     sbi.neural_nets.embedding_nets.ResNetEmbedding2D
+    sbi.neural_nets.embedding_nets.TransformerEmbedding
 
 
 Training

sbi/neural_nets/embedding_nets/transformer.py

Lines changed: 125 additions & 14 deletions
@@ -628,6 +628,116 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
 
 
 class TransformerEmbedding(nn.Module):
+    r"""
+    Transformer-based embedding network for **time series** and **image** data.
+
+    This module provides a flexible embedding architecture that supports both
+    (1) 1D / multivariate time series (e.g., experimental trials, temporal signals),
+    and
+    (2) image inputs via a lightweight Vision Transformer (ViT)-style patch embedding.
+
+    It is designed for simulation-based inference (SBI) workflows where raw
+    observations must be encoded into fixed-dimensional embeddings before passing
+    them to a neural density estimator.
+
+    Parameters
+    ----------
+    pos_emb :
+        Positional embedding type. One of ``{"rotary", "positional", "none"}``.
+    pos_emb_base :
+        Base frequency for rotary positional embeddings.
+    rms_norm_eps :
+        Epsilon for RMSNorm layers.
+    router_jitter_noise :
+        Noise added when routing tokens to MoE experts.
+    vit_dropout :
+        Dropout applied inside ViT patch embedding layers.
+    mlp_activation :
+        Activation used inside the feedforward blocks.
+    is_causal :
+        If ``True``, applies a causal mask during attention (useful for time series).
+    vit :
+        If ``True``, enables Vision Transformer mode for 2D image inputs.
+    num_hidden_layers :
+        Number of transformer encoder blocks.
+    num_attention_heads :
+        Number of self-attention heads.
+    num_key_value_heads :
+        Number of KV heads (for multi-query attention).
+    intermediate_size :
+        Hidden dimension of the feedforward network (or MoE experts).
+    ffn :
+        Feedforward type. One of ``{"mlp", "moe"}``.
+    head_dim :
+        Per-head embedding dimension. If ``None``, inferred as
+        ``feature_space_dim // num_attention_heads``.
+    attention_dropout :
+        Dropout used inside the attention mechanism.
+    feature_space_dim :
+        Dimensionality of the token embeddings flowing through the transformer.
+        - For time series, this is the model dimension.
+        - For images (``vit=True``), this is the post-patch-projection embedding size.
+    final_emb_dimension :
+        Output embedding dimension. Defaults to ``feature_space_dim // 2``.
+    image_size :
+        Input image height/width (only if ``vit=True``).
+    patch_size :
+        ViT patch size (only if ``vit=True``).
+    num_channels :
+        Number of image channels for ViT mode.
+    num_local_experts :
+        Number of MoE experts (only relevant when ``ffn="moe"``).
+    num_experts_per_tok :
+        How many experts each token is routed to in MoE mode.
+
+    Notes
+    -----
+    **Time-series mode (``vit=False``)**
+
+    - Inputs of shape ``(batch, seq_len)`` (scalar series) are automatically
+      projected to ``(batch, seq_len, feature_space_dim)``.
+    - Inputs of shape ``(batch, seq_len, features)`` are used as-is.
+    - Causal masking is applied if ``is_causal=True`` (default).
+    - Suitable for experimental trials, temporal dynamics, or sets of sequential
+      observations.
+
+    **Image mode (``vit=True``)**
+
+    - Inputs must have shape ``(batch, channels, height, width)``.
+    - Images are patchified, linearly projected, and fed to the transformer.
+    - Causal masking is disabled in this mode.
+
+    **Output**
+
+    The embedding is obtained by selecting the final token and applying a linear
+    projection, resulting in a tensor of shape:
+
+    ``(batch, final_emb_dimension)``
+
+    Example
+    -------
+    **1D time series (default mode)**::
+
+        from sbi.neural_nets.embedding_nets import TransformerEmbedding
+        import torch
+
+        x = torch.randn(16, 100)  # (batch, seq_len)
+        emb = TransformerEmbedding(feature_space_dim=64)
+        z = emb(x)
+
+    **Image input (ViT-style)**::
+
+        from sbi.neural_nets.embedding_nets import TransformerEmbedding
+        import torch
+
+        x = torch.randn(8, 3, 64, 64)  # (batch, C, H, W)
+        emb = TransformerEmbedding(
+            vit=True,
+            image_size=64,
+            patch_size=8,
+            num_channels=3,
+            feature_space_dim=128,
+        )
+        z = emb(x)
+    """
+
     def __init__(
         self,
         *,
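The docstring above lists the MoE routing parameters (``ffn="moe"``, ``num_local_experts``, ``num_experts_per_tok``, ``router_jitter_noise``), but its Example section only shows the default MLP and ViT modes. A minimal sketch of an MoE configuration, using only keyword arguments documented in this commit; the expert counts and model dimension are illustrative values, not defaults taken from the code::

    from sbi.neural_nets.embedding_nets import TransformerEmbedding
    import torch

    x = torch.randn(16, 100)  # (batch, seq_len), as in the docstring example

    # Replace the dense feedforward block with a mixture-of-experts:
    # each token is routed to 2 of 4 experts (illustrative counts).
    emb = TransformerEmbedding(
        feature_space_dim=64,
        ffn="moe",
        num_local_experts=4,
        num_experts_per_tok=2,
    )
    z = emb(x)  # (batch, final_emb_dimension) = (16, 32) with the documented default
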
@@ -657,41 +767,42 @@ def __init__(
         super().__init__()
         """
         Main class for constructing a transformer embedding
-        Basic configuration parameters:
+
+        Args:
            pos_emb: position encoding to be used, currently available:
-           {"rotary", "positional", "none"}
+               {"rotary", "positional", "none"}
            pos_emb_base: base used to construct the positional encoding
            rms_norm_eps: noise added to the rms variance computation
            ffn: feedforward layer used after computing the attention:
-           {"mlp", "moe"}
+               {"mlp", "moe"}
            mlp_activation: activation function to be used within the ffn
-           layer
+               layer
            is_causal: specifies whether causal mask should be created
            vit: specifies whether a convolutional layer should be used for
-           processing images, inspired by the vision transformer
+               processing images, inspired by the vision transformer
            num_hidden_layers: number of transformer blocks
            num_attention_heads: number of attention heads
            num_key_value_heads: number of key/value heads
            feature_space_dim: dimension of the feature vectors
            intermediate_size: hidden size of the feedforward layer
-          head_dim: dimension of key/query vectors
+           head_dim: dimension of key/query vectors
            attention_dropout: value for the dropout of the attention layer
 
            MoE:
            router_jitter_noise: noise added before routing the input vectors
-           to the experts
+               to the experts
            num_local_experts: total number of experts
            num_experts_per_tok: number of experts each token is assigned to
 
            ViT:
            feature_space_dim: dimension of the feature vectors after
-           preprocessing the images
+               preprocessing the images
            image_size: dimension of the squared image used to create
-           the positional encoders
-           a rectangular image can be used at training/inference time by
-           resampling the encoders
+               the positional encoders
+               a rectangular image can be used at training/inference time by
+               resampling the encoders
            patch_size: size of the square patches used to create the
-           positional encoders
+               positional encoders
            num_channels: number of channels of the input image
            vit_dropout: value for the dropout of the ViT patch embedding layer
         """
@@ -797,8 +908,8 @@ def forward(
         """
         Args:
             input: input of shape `(batch, seq_len,
-            feature_space_dim)`
-            or `(batch, num_channels, height, width)` if using ViT
+                feature_space_dim)` or `(batch, num_channels,
+                height, width)` if using ViT
             attention_mask:
                 attention mask of size `(batch_size, sequence_length)`
             output_attentions:
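The class docstring frames ``TransformerEmbedding`` as a way to compress raw observations before they reach a neural density estimator. A minimal sketch of that hookup, assuming the usual ``sbi`` pattern of passing an ``embedding_net`` into the ``posterior_nn`` factory; the builder call below reflects typical ``sbi`` usage and is not part of this commit::

    from sbi.neural_nets import posterior_nn
    from sbi.neural_nets.embedding_nets import TransformerEmbedding

    # Embed raw 1D time-series observations into fixed-size vectors.
    embedding_net = TransformerEmbedding(feature_space_dim=64)

    # Density-estimator builder that applies the embedding to x before the flow;
    # this builder is then handed to an NPE-style trainer in the usual sbi workflow.
    build_density_estimator = posterior_nn(model="maf", embedding_net=embedding_net)
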
