
Commit 3a69644

Authored by han-ol, LarsKue, vpratz, stefanradev93, and marvinschmitt
Add projectors to DeepSet (#453)
* v2.0.2 (#447)
* [no ci] notebook tests: increase timeout, fix platform/backend-dependent code. Torch is very slow, so I had to increase the timeout accordingly.
* Enable use of summary networks with functional API again (#434)
* summary networks: add tests for using functional API
* fix build functions for use with functional API
* [no ci] docs: add GitHub and Discourse links, reorder navbar
* [no ci] docs: acknowledge scikit-learn website
* [no ci] docs: capitalize navigation headings
* More tests (#437)
* fix docs of coupling flow
* add additional tests
* Automatically run slow tests when main is involved (#438). In addition, this PR limits the slow tests to Windows and Python 3.10. The choices are somewhat arbitrary; my thought was to test the setup not covered as much through use by the devs.
* Update dispatch
* Update dispatching distributions
* Improve workflow tests with multiple summary nets / approximators
* Fix zombie find_distribution import
* Add readme entry [no ci]
* Update README: NumFOCUS affiliation, awesome-abi list (#445)
* fix is_symbolic_tensor
* remove multiple batch sizes, remove multiple Python version tests, remove update-workflows branch from workflow style tests, add __init__ and conftest to test_point_approximators (#443)
* implement compile_from_config and get_compile_config (#442)
* add optimizer build to compile_from_config
* Fix Optimal Transport for Compiled Contexts (#446)
* remove the is_symbolic_tensor check, because it would otherwise skip the whole function for compiled contexts
* skip pyabc test
* fix sinkhorn and log_sinkhorn message formatting for jax by making the warning message worse
* update dispatch tests for more coverage
* Update issue templates (#448)
* Hotfix Version 2.0.1 (#431)
* fix optimal transport config (#429)
* run linter
* [skip-ci] bump version to 2.0.1
* Update issue templates
* Robustify kwargs passing in inference networks, add class variables
* fix convergence method to debug for non-log sinkhorn
* Bump optimal transport default to False
* use logging.info for backend selection instead of logging.debug
* fix model comparison approximator
* improve docs and type hints
* improve One-Sample T-Test notebook:
  - use torch as default backend
  - reduce range of N so users of jax won't be stuck with a slow notebook
  - use BayesFlow's built-in MLP instead of a keras.Sequential solution
  - general code cleanup
* remove backend print
* [skip ci] turn all single-quoted strings into double-quoted strings
* drafting feature
* Initialize projectors for invariant and equivariant DeepSet layers
* implement requested changes and improve activation

Co-authored-by: Lars <lars@kuehmichel.de>
Co-authored-by: Valentin Pratz <git@valentinpratz.de>
Co-authored-by: Valentin Pratz <112951103+vpratz@users.noreply.github.com>
Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
Co-authored-by: Marvin Schmitt <35921281+marvinschmitt@users.noreply.github.com>
1 parent 62675c3 commit 3a69644

3 files changed: 11 additions, 3 deletions

bayesflow/networks/deep_set/deep_set.py (2 additions, 2 deletions)

@@ -30,7 +30,7 @@ def __init__(
         mlp_widths_invariant_inner: Sequence[int] = (64, 64),
         mlp_widths_invariant_outer: Sequence[int] = (64, 64),
         mlp_widths_invariant_last: Sequence[int] = (64, 64),
-        activation: str = "gelu",
+        activation: str = "silu",
         kernel_initializer: str = "he_normal",
         dropout: int | float | None = 0.05,
         spectral_normalization: bool = False,
@@ -72,7 +72,7 @@ def __init__(
         mlp_widths_invariant_last : Sequence[int], optional
             Widths of the MLP layers in the final invariant transformation. Default is (64, 64).
         activation : str, optional
-            Activation function used throughout the network, such as "gelu". Default is "gelu".
+            Activation function used throughout the network, such as "gelu". Default is "silu".
         kernel_initializer : str, optional
            Initialization strategy for kernel weights, such as "he_normal". Default is "he_normal".
         dropout : int, float, or None, optional
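For context on the default-activation switch above, here is a minimal NumPy sketch (not part of the diff) of the two functions: SiLU, the new default, is x·sigmoid(x), while GELU, the previous default, is x·Φ(x); the common tanh approximation of GELU is used below.

```python
import numpy as np

def silu(x):
    # SiLU ("swish"): x * sigmoid(x) -- the new default activation
    return x / (1.0 + np.exp(-x))

def gelu(x):
    # GELU, tanh approximation of x * Phi(x) -- the previous default
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

x = np.linspace(-3.0, 3.0, 7)
# both are smooth, pass through the origin, and approach ReLU for large |x|
assert silu(0.0) == 0.0 and gelu(0.0) == 0.0
assert np.all(np.abs(silu(x) - gelu(x)) < 0.3)
```

The two curves are close over the typical activation range, so the swap mainly changes smoothness details rather than overall behavior.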

bayesflow/networks/deep_set/equivariant_layer.py (5 additions, 1 deletion)

@@ -94,6 +94,7 @@ def __init__(
             kernel_initializer=kernel_initializer,
             spectral_normalization=spectral_normalization,
         )
+        self.out_fc_projector = keras.layers.Dense(mlp_widths_equivariant[-1], kernel_initializer=kernel_initializer)

         self.layer_norm = layers.LayerNormalization() if layer_norm else None

@@ -137,7 +138,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         output_set = ops.concatenate([input_set, invariant_summary], axis=-1)

         # Pass through final equivariant transform + residual
-        output_set = input_set + self.equivariant_fc(output_set, training=training)
+        out_fc = self.equivariant_fc(output_set, training=training)
+        out_projected = self.out_fc_projector(out_fc)
+        output_set = input_set + out_projected

         if self.layer_norm is not None:
             output_set = self.layer_norm(output_set, training=training)
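The change above inserts a linear Dense projector between the equivariant MLP and the residual add. A minimal NumPy sketch of that fc → linear projector → residual pattern (all shapes and weights here are hypothetical stand-ins, not the library's actual layers):

```python
import numpy as np

rng = np.random.default_rng(0)
set_size, width, summary_dim, hidden = 5, 8, 4, 32

input_set = rng.normal(size=(set_size, width))
invariant_summary = rng.normal(size=(set_size, summary_dim))  # stand-in, tiled over the set

# concatenate input and invariant summary along the feature axis
output_set = np.concatenate([input_set, invariant_summary], axis=-1)

# stand-in for equivariant_fc: an MLP whose last layer has `hidden` units
w_fc = rng.normal(size=(width + summary_dim, hidden)) * 0.1
out_fc = np.tanh(output_set @ w_fc)

# the new projector: a plain linear layer mapping `hidden` down to `width`,
# so the residual add below is well-defined
w_proj = rng.normal(size=(hidden, width)) * 0.1
out_projected = out_fc @ w_proj

output_set = input_set + out_projected  # residual connection
assert output_set.shape == (set_size, width)
```

Because the projector has no activation, the residual branch ends in a plain linear map regardless of the MLP's internal widths.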

bayesflow/networks/deep_set/invariant_layer.py (4 additions, 0 deletions)

@@ -74,6 +74,7 @@ def __init__(
             kernel_initializer=kernel_initializer,
             spectral_normalization=spectral_normalization,
         )
+        self.inner_projector = keras.layers.Dense(mlp_widths_inner[-1], kernel_initializer=kernel_initializer)

         self.outer_fc = MLP(
             mlp_widths_outer,
@@ -82,6 +83,7 @@ def __init__(
             kernel_initializer=kernel_initializer,
             spectral_normalization=spectral_normalization,
         )
+        self.outer_projector = keras.layers.Dense(mlp_widths_outer[-1], kernel_initializer=kernel_initializer)

         # Pooling function as keras layer for sum decomposition: inner( pooling( inner(set) ) )
         if pooling_kwargs is None:
@@ -106,8 +108,10 @@ def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
         """

         set_summary = self.inner_fc(input_set, training=training)
+        set_summary = self.inner_projector(set_summary)
         set_summary = self.pooling_layer(set_summary, training=training)
         set_summary = self.outer_fc(set_summary, training=training)
+        set_summary = self.outer_projector(set_summary)
         return set_summary

    @sanitize_input_shape
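Putting the invariant-layer changes together, the set summary now flows through inner MLP → linear projector → pooling → outer MLP → linear projector. A NumPy sketch (tanh matmuls stand in for the MLPs; all widths are hypothetical) showing that the result remains permutation-invariant:

```python
import numpy as np

rng = np.random.default_rng(1)
set_size, width, inner_w, outer_w = 6, 4, 16, 8

# fixed random weights standing in for the trained layers
w_inner = rng.normal(size=(width, inner_w)) * 0.1
w_inner_proj = rng.normal(size=(inner_w, inner_w)) * 0.1   # inner_projector (linear)
w_outer = rng.normal(size=(inner_w, outer_w)) * 0.1
w_outer_proj = rng.normal(size=(outer_w, outer_w)) * 0.1   # outer_projector (linear)

def invariant_layer(input_set):
    # inner MLP stand-in, then the linear inner projector
    s = np.tanh(input_set @ w_inner) @ w_inner_proj
    s = s.mean(axis=0)  # pooling over the set axis
    # outer MLP stand-in, then the linear outer projector
    return np.tanh(s @ w_outer) @ w_outer_proj

x = rng.normal(size=(set_size, width))
summary = invariant_layer(x)
assert summary.shape == (outer_w,)

# permutation invariance: reordering the set elements leaves the summary unchanged
perm = rng.permutation(set_size)
assert np.allclose(summary, invariant_layer(x[perm]))
```

Since both projectors act element-wise along the feature axis (before pooling) or on the pooled vector (after), they do not break the sum-decomposition structure outer(pooling(inner(set))).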

0 commit comments