Skip to content

Commit 12b06b9

Browse files
authored
DeepSet: Adapt output dimension of invariant module inside the equivariant module (#557) (#561)
* adapt output dim of invariant module in equivariant module See #557. The DeepSet showed bad performance and was not able to learn diverse summary statistics. Reducing the dimension of the output of the invariant module inside the equivariant module improves this, probably because the invidividual information of each set member gains importance compared to the shared information provided by the invariant module. There might be better settings for this, so we might update the default later on. However, this is already an improvement over the previous setting. * DeepSet: adapt docstring to reflect code
1 parent 4e47cc4 commit 12b06b9

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

bayesflow/networks/deep_set/deep_set.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
output_pooling: str = "mean",
2929
mlp_widths_equivariant: Sequence[int] = (64, 64),
3030
mlp_widths_invariant_inner: Sequence[int] = (64, 64),
31-
mlp_widths_invariant_outer: Sequence[int] = (64, 64),
31+
mlp_widths_invariant_outer: Sequence[int] = (64, 4),
3232
mlp_widths_invariant_last: Sequence[int] = (64, 64),
3333
activation: str = "silu",
3434
kernel_initializer: str = "he_normal",
@@ -68,7 +68,7 @@ def __init__(
6868
mlp_widths_invariant_inner : Sequence[int], optional
6969
Widths of the inner MLP layers within the invariant module. Default is (64, 64).
7070
mlp_widths_invariant_outer : Sequence[int], optional
71-
Widths of the outer MLP layers within the invariant module. Default is (64, 64).
71+
Widths of the outer MLP layers within the invariant module. Default is (64, 4).
7272
mlp_widths_invariant_last : Sequence[int], optional
7373
Widths of the MLP layers in the final invariant transformation. Default is (64, 64).
7474
activation : str, optional
@@ -80,7 +80,7 @@ def __init__(
8080
spectral_normalization : bool, optional
8181
Whether to apply spectral normalization to stabilize training. Default is False.
8282
**kwargs
83-
Additional keyword arguments passed to the equivariant and invariant modules.
83+
Additional keyword arguments passed to the base class.
8484
"""
8585

8686
super().__init__(**kwargs)

0 commit comments

Comments
 (0)