Description
This thread highlighted performance issues with our deep set implementation, where many inference variables remain entirely uninformed by the learned summary statistics. I have since conducted a few additional experiments, and the deep set's performance does not improve even with an exceedingly large summary_dim, in contrast to the set transformer.
By trial and error, I have identified the invariant module inside the equivariant layer as a contributor to this specific problem. In EquivariantLayer.call, we have the following code:
def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
    [...]
    input_set = self.input_projector(input_set)
    # Store shape of input_set, will be (batch_size, ..., set_size, some_dim)
    shape = ops.shape(input_set)
    # Example: Output dim is (batch_size, ..., set_size, representation_dim)
    invariant_summary = self.invariant_module(input_set, training=training)
    invariant_summary = ops.expand_dims(invariant_summary, axis=-2)
    tiler = [1] * len(shape)
    tiler[-2] = shape[-2]
    invariant_summary = ops.tile(invariant_summary, tiler)
    # Concatenate each input entry with the repeated invariant embedding
    output_set = ops.concatenate([input_set, invariant_summary], axis=-1)
    [...]
Removing the invariant summary, or reducing its influence by setting mlp_widths_invariant_outer=(64, 2), mlp_widths_invariant_inner=(64, 2), reduces the problem of information getting lost in the deep set. My guess for the mechanism is that the invariant_summary being identical for each input entry creates a strong co-dependence between the variables, reducing the variability in the outputs the deep set can produce. Note that this does not eliminate the problem entirely, but brings it down to a level more comparable to the set transformer.
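To make this concrete, here is a minimal standalone sketch of the tiling step (my own illustration, not the actual BayesFlow layer): a mean-pool followed by a fixed linear map stands in for the invariant module, and after tiling every entry of a set is concatenated with exactly the same summary vector.

import numpy as np
from keras import ops

# Toy set: batch of 2 sets, 5 entries each, 3 features per entry
rng = np.random.default_rng(0)
input_set = ops.convert_to_tensor(rng.normal(size=(2, 5, 3)).astype("float32"))

# Stand-in for self.invariant_module: mean-pool over the set axis,
# then a fixed linear map to a hypothetical representation_dim of 4
pooled = ops.mean(input_set, axis=-2)                     # (2, 3)
invariant_summary = ops.matmul(pooled, ops.ones((3, 4)))  # (2, 4)

# Same tiling as in EquivariantLayer.call
shape = ops.shape(input_set)
invariant_summary = ops.expand_dims(invariant_summary, axis=-2)  # (2, 1, 4)
tiler = [1] * len(shape)
tiler[-2] = shape[-2]
invariant_summary = ops.tile(invariant_summary, tiler)           # (2, 5, 4)

# Every entry in a set now carries an identical 4-dimensional block
output_set = ops.concatenate([input_set, invariant_summary], axis=-1)  # (2, 5, 7)
print(ops.shape(output_set))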
However, I'm not sure yet how this will affect performance on problems with a different structure.
@stefanradev93 @Chase-Grajeda As you have worked on this code, what would you suggest as good problems to benchmark this on? Also, if you have other ideas about what could be at play here and how we could fix it, please let me know.
Adapted example code from the thread
import keras
import bayesflow as bf
import numpy as np
D = 12
N = 10
# exceedingly large summary_dim, so that it is not the limiting factor
summary_dim = 200
rng = np.random.default_rng(2025)
def prior():
    variance = rng.uniform(size=D)
    return {"variance": variance}
def likelihood(variance):
    y = rng.normal(loc=np.zeros_like(variance), scale=np.sqrt(variance), size=(N, D))
    return {"y": y}
simulator = bf.make_simulator([prior, likelihood])
validation_data = simulator.sample(100)
print("shapes", keras.tree.map_structure(keras.ops.shape, validation_data))
# summary network options
summary_network = bf.networks.DeepSet(
    summary_dim=summary_dim,
    # mlp_widths_invariant_outer=(64, 2),
    # mlp_widths_invariant_inner=(64, 2),
)
# summary_network = bf.networks.SetTransformer(summary_dim=summary_dim)
workflow = bf.BasicWorkflow(
    simulator=simulator,
    summary_network=summary_network,
    summary_variables="y",
    inference_variables=["variance"],
    initial_learning_rate=5e-4,
    standardize="all",
)
workflow.fit_online(num_batches_per_epoch=1000, epochs=10, validation_data=validation_data, batch_size=32);
test_data = simulator.sample(200)
workflow.plot_default_diagnostics(test_data);
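As a rough diagnostic for the original symptom (summary dimensions that carry no information), one could pass the simulated y through the trained summary network and look at the per-dimension standard deviation of the outputs; near-constant dimensions are effectively unused. This is my own ad-hoc check, not part of the BasicWorkflow API, and it calls the summary network directly on unstandardized data, so treat it as a coarse indicator only:

# Ad-hoc check: how many summary dimensions actually vary across datasets?
# (ignores the workflow's input standardization, so only a coarse indicator)
summaries = summary_network(keras.ops.convert_to_tensor(test_data["y"], dtype="float32"))
summary_std = keras.ops.std(summaries, axis=0)
print("per-dimension std of summaries:", keras.ops.convert_to_numpy(summary_std))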