@@ -628,6 +628,116 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
628628
629629
630630class TransformerEmbedding(nn.Module):
631+ r"""
632+ Transformer-based embedding network for **time series** and **image** data.
633+
634+ This module provides a flexible embedding architecture that supports both
635+ (1) 1D / multivariate time series (e.g., experimental trials, temporal signals),
636+ and
637+ (2) image inputs via a lightweight Vision Transformer (ViT)-style patch embedding.
638+
639+ It is designed for simulation-based inference (SBI) workflows where raw
640+ observations must be encoded into fixed-dimensional embeddings before passing
641+ them to a neural density estimator.
642+
643+ Parameters
644+ ----------
645+ pos_emb :
646+ Positional embedding type. One of ``{"rotary", "positional", "none"}``.
647+ pos_emb_base :
648+ Base frequency for rotary positional embeddings.
649+ rms_norm_eps :
650+ Epsilon for RMSNorm layers.
651+ router_jitter_noise :
652+ Noise added when routing tokens to MoE experts.
653+ vit_dropout :
654+ Dropout applied inside ViT patch embedding layers.
655+ mlp_activation :
656+ Activation used inside the feedforward blocks.
657+ is_causal :
658+ If ``True``, applies a causal mask during attention (useful for time-series).
659+ vit :
660+ If ``True``, enables Vision Transformer mode for 2D image inputs.
661+ num_hidden_layers :
662+ Number of transformer encoder blocks.
663+ num_attention_heads :
664+ Number of self-attention heads.
665+ num_key_value_heads :
666+ Number of key/value heads (for grouped-query / multi-query attention).
667+ intermediate_size :
668+ Hidden dimension of feedforward network (or MoE experts).
669+ ffn :
670+ Feedforward type. One of ``{"mlp", "moe"}``.
671+ head_dim :
672+ Per-head embedding dimension. If ``None``, inferred as
673+ ``feature_space_dim // num_attention_heads``.
674+ attention_dropout :
675+ Dropout used inside the attention mechanism.
676+ feature_space_dim :
677+ Dimensionality of the token embeddings flowing through the transformer.
678+ - For time-series, this is the model dimension.
679+ - For images (``vit=True``), this is the post-patch-projection embedding size.
680+ final_emb_dimension :
681+ Output embedding dimension. Defaults to ``feature_space_dim // 2``.
682+ image_size :
683+ Input image height/width (only if ``vit=True``).
684+ patch_size :
685+ ViT patch size (only if ``vit=True``).
686+ num_channels :
687+ Number of image channels for ViT mode.
688+ num_local_experts :
689+ Number of MoE experts (only relevant when ``ffn="moe"``).
690+ num_experts_per_tok :
691+ How many experts each token is routed to in MoE mode.
692+
693+ Notes
694+ -----
695+ **Time-series mode (``vit=False``)**
696+ - Inputs of shape ``(batch, seq_len)`` (scalar series) are automatically
697+ projected to ``(batch, seq_len, feature_space_dim)``.
698+ - Inputs of shape ``(batch, seq_len, features)`` are used as-is.
699+ - Causal masking is applied if ``is_causal=True`` (default).
700+ - Suitable for experimental trials, temporal dynamics, or sets of sequential
701+ observations.
702+
703+ **Image mode (``vit=True``)**
704+ - Inputs must have shape ``(batch, channels, height, width)``.
705+ - Images are patchified, linearly projected, and fed to the transformer.
706+ - Causal masking is disabled in this mode.
707+
708+ **Output**
709+ The embedding is obtained by selecting the final token and applying a linear
710+ projection, resulting in a tensor of shape:
711+
712+ ``(batch, final_emb_dimension)``
713+
714+ Example
715+ -------
716+ **1D time-series (default mode)**::
717+
718+ from sbi.neural_nets.embedding_nets import TransformerEmbedding
719+ import torch
720+
721+ x = torch.randn(16, 100) # (batch, seq_len)
722+ emb = TransformerEmbedding(feature_space_dim=64)
723+ z = emb(x)
724+
725+ **Image input (ViT-style)**::
726+
727+ from sbi.neural_nets.embedding_nets import TransformerEmbedding
728+ import torch
729+
730+ x = torch.randn(8, 3, 64, 64) # (batch, C, H, W)
731+ emb = TransformerEmbedding(
732+ vit=True,
733+ image_size=64,
734+ patch_size=8,
735+ num_channels=3,
736+ feature_space_dim=128,
737+ )
738+ z = emb(x)
739+ """
740+
631741 def __init__(
632742 self,
633743 *,
@@ -657,41 +767,41 @@ def __init__(
657767 super().__init__()
658768 """
659769 Main class for constructing a transformer embedding
660- Basic configuration parameters :
770+ Args:
661771 pos_emb: position encoding to be used, currently available:
662- {"rotary", "positional", "none"}
772+ {"rotary", "positional", "none"}
663773 pos_emb_base: base used to construct the positional encoding
664774 rms_norm_eps: epsilon added to the rms variance computation for numerical stability
665775 ffn: feedforward layer used after computing the attention:
666- {"mlp", "moe"}
776+ {"mlp", "moe"}
667777 mlp_activation: activation function to be used within the ffn
668- layer
778+ layer
669779 is_causal: specifies whether a causal mask should be created
670780 vit: specifies whether a convolutional layer should be used for
671- processing images, inspired by the vision transformer
781+ processing images, inspired by the vision transformer
672782 num_hidden_layers: number of transformer blocks
673783 num_attention_heads: number of attention heads
674784 num_key_value_heads: number of key/value heads
675785 feature_space_dim: dimension of the feature vectors
676786 intermediate_size: hidden size of the feedforward layer
677- head_dim: dimension key/query vectors
787+ head_dim: dimension of the key/query vectors
678788 attention_dropout: value for the dropout of the attention layer
679789
680790 MoE:
681791 router_jitter_noise: noise added before routing the input vectors
682- to the experts
792+ to the experts
683793 num_local_experts: total number of experts
684794 num_experts_per_tok: number of experts each token is assigned to
685795
686796 ViT
687797 feature_space_dim: dimension of the feature vectors after
688- preprocessing the images
798+ preprocessing the images
689799 image_size: dimension of the square image used to create
690- the positional encoders
691- a rectagular image can be used at training/inference time by
692- resampling the encoders
800+ the positional encoders
801+ a rectangular image can be used at training/inference time by
802+ resampling the encoders
693803 patch_size: size of the square patches used to create the
694- positional encoders
804+ positional encoders
695805 num_channels: number of channels of the input image
696806 vit_dropout: value for the dropout of the ViT patch-embedding layers
697807 """
@@ -797,8 +907,8 @@ def forward(
797907 """
798908 Args:
799909 input: input of shape `(batch, seq_len,
800- feature_space_dim)`
801- or `(batch, num_channels, height, width)` if using ViT
910+ feature_space_dim)` or `(batch, num_channels,
911+ height, width)` if using ViT
802912 attention_mask:
803913 attention mask of size `(batch_size, sequence_length)`
804914 output_attentions:
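
Usage sketch for the SBI workflow the new docstring describes: the embedding maps raw observations to ``(batch, final_emb_dimension)`` (``feature_space_dim // 2`` by default) and can then be handed to a density-estimator factory. The ``posterior_nn`` call below assumes the usual ``sbi`` factory signature with an ``embedding_net`` argument; treat it as illustrative rather than the exact API of this branch::

    from sbi.neural_nets import posterior_nn
    from sbi.neural_nets.embedding_nets import TransformerEmbedding
    import torch

    # Multivariate trials: inputs of shape (batch, seq_len, features) are
    # used as-is, here with features matching feature_space_dim.
    x = torch.randn(16, 100, 64)
    emb = TransformerEmbedding(feature_space_dim=64)
    z = emb(x)
    assert z.shape == (16, 32)  # final_emb_dimension defaults to feature_space_dim // 2

    # Hand the embedding net to a density-estimator factory (illustrative).
    density_estimator_builder = posterior_nn(model="maf", embedding_net=emb)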
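The mixture-of-experts feedforward variant uses only parameters documented above (``ffn="moe"``, ``num_local_experts``, ``num_experts_per_tok``, ``router_jitter_noise``); a minimal sketch with illustrative values, assuming the remaining defaults::

    from sbi.neural_nets.embedding_nets import TransformerEmbedding
    import torch

    # Route each token to 2 of 4 experts; the jitter value is illustrative.
    emb_moe = TransformerEmbedding(
        feature_space_dim=64,
        ffn="moe",
        num_local_experts=4,
        num_experts_per_tok=2,
        router_jitter_noise=0.01,
    )
    z = emb_moe(torch.randn(16, 100))  # scalar series -> (16, 32)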