Commit 20f7d47

fix docstring indentation and add class level docstring
1 parent 41b2652 commit 20f7d47

File tree

2 files changed: +125 -14 lines changed


docs/sbi.rst

Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ Embedding nets
     sbi.neural_nets.embedding_nets.PermutationInvariantEmbedding
     sbi.neural_nets.embedding_nets.ResNetEmbedding1D
     sbi.neural_nets.embedding_nets.ResNetEmbedding2D
+    sbi.neural_nets.embedding_nets.TransformerEmbedding


 Training

sbi/neural_nets/embedding_nets/transformer.py

Lines changed: 124 additions & 14 deletions

@@ -628,6 +628,116 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:


 class TransformerEmbedding(nn.Module):
+    r"""
+    Transformer-based embedding network for **time series** and **image** data.
+
+    This module provides a flexible embedding architecture that supports both
+    (1) 1D / multivariate time series (e.g., experimental trials, temporal signals),
+    and
+    (2) image inputs via a lightweight Vision Transformer (ViT)-style patch embedding.
+
+    It is designed for simulation-based inference (SBI) workflows where raw
+    observations must be encoded into fixed-dimensional embeddings before passing
+    them to a neural density estimator.
+
+    Parameters
+    ----------
+    pos_emb :
+        Positional embedding type. One of ``{"rotary", "positional", "none"}``.
+    pos_emb_base :
+        Base frequency for rotary positional embeddings.
+    rms_norm_eps :
+        Epsilon for RMSNorm layers.
+    router_jitter_noise :
+        Noise added when routing tokens to MoE experts.
+    vit_dropout :
+        Dropout applied inside ViT patch embedding layers.
+    mlp_activation :
+        Activation used inside the feedforward blocks.
+    is_causal :
+        If ``True``, applies a causal mask during attention (useful for time-series).
+    vit :
+        If ``True``, enables Vision Transformer mode for 2D image inputs.
+    num_hidden_layers :
+        Number of transformer encoder blocks.
+    num_attention_heads :
+        Number of self-attention heads.
+    num_key_value_heads :
+        Number of KV heads (for multi-query attention).
+    intermediate_size :
+        Hidden dimension of feedforward network (or MoE experts).
+    ffn :
+        Feedforward type. One of ``{"mlp", "moe"}``.
+    head_dim :
+        Per-head embedding dimension. If ``None``, inferred as
+        ``feature_space_dim // num_attention_heads``.
+    attention_dropout :
+        Dropout used inside the attention mechanism.
+    feature_space_dim :
+        Dimensionality of the token embeddings flowing through the transformer.
+        - For time-series, this is the model dimension.
+        - For images (``vit=True``), this is the post-patch-projection embedding size.
+    final_emb_dimension :
+        Output embedding dimension. Defaults to ``feature_space_dim // 2``.
+    image_size :
+        Input image height/width (only if ``vit=True``).
+    patch_size :
+        ViT patch size (only if ``vit=True``).
+    num_channels :
+        Number of image channels for ViT mode.
+    num_local_experts :
+        Number of MoE experts (only relevant when ``ffn="moe"``).
+    num_experts_per_tok :
+        How many experts each token is routed to in MoE mode.
+
+    Notes
+    -----
+    **Time-series mode (``vit=False``)**
+    - Inputs of shape ``(batch, seq_len)`` (scalar series) are automatically
+      projected to ``(batch, seq_len, feature_space_dim)``.
+    - Inputs of shape ``(batch, seq_len, features)`` are used as-is.
+    - Causal masking is applied if ``is_causal=True`` (default).
+    - Suitable for experimental trials, temporal dynamics, or sets of sequential
+      observations.
+
+    **Image mode (``vit=True``)**
+    - Inputs must have shape ``(batch, channels, height, width)``.
+    - Images are patchified, linearly projected, and fed to the transformer.
+    - Causal masking is disabled in this mode.
+
+    **Output**
+    The embedding is obtained by selecting the final token and applying a linear
+    projection, resulting in a tensor of shape:
+
+    ``(batch, final_emb_dimension)``
+
+    Example
+    -------
+    **1D time-series (default mode)**::
+
+        from sbi.neural_nets.embedding_nets import TransformerEmbedding
+        import torch
+
+        x = torch.randn(16, 100)  # (batch, seq_len)
+        emb = TransformerEmbedding(feature_space_dim=64)
+        z = emb(x)
+
+    **Image input (ViT-style)**::
+
+        from sbi.neural_nets.embedding_nets import TransformerEmbedding
+        import torch
+
+        x = torch.randn(8, 3, 64, 64)  # (batch, C, H, W)
+        emb = TransformerEmbedding(
+            vit=True,
+            image_size=64,
+            patch_size=8,
+            num_channels=3,
+            feature_space_dim=128,
+        )
+        z = emb(x)
+    """
+
     def __init__(
         self,
         *,
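
For reference, the documented defaults imply the following output shape; a minimal
sketch, assuming the stated default ``final_emb_dimension = feature_space_dim // 2``:

    import torch

    from sbi.neural_nets.embedding_nets import TransformerEmbedding

    x = torch.randn(16, 100)  # (batch, seq_len) scalar time series
    emb = TransformerEmbedding(feature_space_dim=64)
    z = emb(x)

    # Documented default: final_emb_dimension = feature_space_dim // 2 = 32
    assert z.shape == (16, 32)
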
@@ -657,41 +767,41 @@ def __init__(
         super().__init__()
         """
         Main class for constructing a transformer embedding
-        Basic configuration parameters:
+        Args:
             pos_emb: position encoding to be used, currently available:
-            {"rotary", "positional", "none"}
+                {"rotary", "positional", "none"}
             pos_emb_base: base used to construct the positional encoding
             rms_norm_eps: noise added to the rms variance computation
             ffn: feedforward layer used after computing the attention:
-            {"mlp", "moe"}
+                {"mlp", "moe"}
             mlp_activation: activation function to be used within the ffn
-            layer
+                layer
             is_causal: specifies whether a causal mask should be created
             vit: specifies whether a convolutional layer should be used for
-            processing images, inspired by the vision transformer
+                processing images, inspired by the vision transformer
             num_hidden_layers: number of transformer blocks
             num_attention_heads: number of attention heads
             num_key_value_heads: number of key/value heads
             feature_space_dim: dimension of the feature vectors
             intermediate_size: hidden size of the feedforward layer
-                head_dim: dimension key/query vectors
+            head_dim: dimension of the key/query vectors
             attention_dropout: value for the dropout of the attention layer

         MoE:
             router_jitter_noise: noise added before routing the input vectors
-            to the experts
+                to the experts
             num_local_experts: total number of experts
             num_experts_per_tok: number of experts each token is assigned to

         ViT
             feature_space_dim: dimension of the feature vectors after
-            preprocessing the images
+                preprocessing the images
             image_size: dimension of the square image used to create
-            the positional encoders
-            a rectangular image can be used at training/inference time by
-            resampling the encoders
+                the positional encoders
+                a rectangular image can be used at training/inference time by
+                resampling the encoders
             patch_size: size of the square patches used to create the
-            positional encoders
+                positional encoders
             num_channels: number of channels of the input image
             vit_dropout: value for the dropout of the attention layer
         """
@@ -797,8 +907,8 @@ def forward(
         """
         Args:
             input: input of shape `(batch, seq_len,
-            feature_space_dim)`
-            or `(batch, num_channels, height, width)` if using ViT
+                feature_space_dim)` or `(batch, num_channels,
+                height, width)` if using ViT
             attention_mask:
                 attention mask of size `(batch_size, sequence_length)`
             output_attentions:
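
In an sbi workflow, the embedding net is typically passed to a density-estimator
builder rather than called directly. A minimal sketch, assuming a recent sbi API in
which ``posterior_nn`` accepts an ``embedding_net`` (the exact import path may differ
between versions):

    from sbi.inference import SNPE
    from sbi.neural_nets import posterior_nn
    from sbi.neural_nets.embedding_nets import TransformerEmbedding

    # Encode raw (batch, seq_len) observations into fixed-size embeddings
    # before they reach the conditional density estimator.
    embedding_net = TransformerEmbedding(feature_space_dim=64)

    density_estimator = posterior_nn(model="maf", embedding_net=embedding_net)
    inference = SNPE(density_estimator=density_estimator)
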
