
DeepSet performance issues #557

@vpratz

Description

This thread highlighted performance issues with our deep set implementation, where many inference variables remain entirely uninformed by the learned summary statistics. I have since conducted a few additional experiments, and the deep set does not improve its performance even with an exceedingly large summary_dim, in contrast to the set transformer.

By trial and error, I have identified the invariant module inside the equivariant layer as a contributor to this specific problem. In EquivariantLayer.call, we have the following code:

    def call(self, input_set: Tensor, training: bool = False, **kwargs) -> Tensor:
        [...]
        input_set = self.input_projector(input_set)

        # Store shape of input_set, will be (batch_size, ..., set_size, some_dim)
        shape = ops.shape(input_set)

        # Example: Output dim is (batch_size, ..., set_size, representation_dim)
        invariant_summary = self.invariant_module(input_set, training=training)
        invariant_summary = ops.expand_dims(invariant_summary, axis=-2)
        tiler = [1] * len(shape)
        tiler[-2] = shape[-2]
        invariant_summary = ops.tile(invariant_summary, tiler)

        # Concatenate each input entry with the repeated invariant embedding
        output_set = ops.concatenate([input_set, invariant_summary], axis=-1)
        [...]
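For intuition, the shape bookkeeping of this step can be reproduced outside the layer. The following is a minimal standalone sketch; the mean-pool stands in for the learned invariant module, and all names besides the keras.ops calls are illustrative:

    import numpy as np
    from keras import ops

    batch_size, set_size, some_dim, representation_dim = 2, 5, 4, 3
    rng = np.random.default_rng(0)
    input_set = ops.convert_to_tensor(
        rng.normal(size=(batch_size, set_size, some_dim)).astype("float32")
    )

    # Stand-in for self.invariant_module: any permutation-invariant map
    # to (batch_size, representation_dim)
    invariant_summary = ops.mean(input_set, axis=-2)[..., :representation_dim]

    invariant_summary = ops.expand_dims(invariant_summary, axis=-2)    # (2, 1, 3)
    invariant_summary = ops.tile(invariant_summary, [1, set_size, 1])  # (2, 5, 3)
    output_set = ops.concatenate([input_set, invariant_summary], axis=-1)

    print(ops.shape(output_set))  # (2, 5, 7)
    # The appended features are identical for every set entry:
    print(ops.all(ops.equal(output_set[:, 0, some_dim:],
                            output_set[:, 1, some_dim:])))  # True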

Removing the invariant summary, or reducing its influence by setting mlp_widths_invariant_outer=(64, 2) and mlp_widths_invariant_inner=(64, 2), reduces the problem of information getting lost in the deep set. My guess at the mechanism: because the invariant_summary is identical for every input entry, it induces a strong co-dependence between the variables, reducing the variability in the outputs the deep set can achieve. Note that this does not eliminate the problem entirely, but it brings it down to a level more comparable to the set transformer.
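For reference, the width-reduction ablation only requires passing the corresponding kwargs when constructing the network (the same kwargs appear, commented out, in the example below):

    # Shrink the invariant MLPs to reduce the influence of the tiled summary
    summary_network = bf.networks.DeepSet(
        summary_dim=summary_dim,
        mlp_widths_invariant_outer=(64, 2),
        mlp_widths_invariant_inner=(64, 2),
    )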

However, I'm not sure yet how this will affect performance on problems with a different structure.
@stefanradev93 @Chase-Grajeda As you have worked on this code, what would you suggest are good problems to benchmark this? Also, if you have other ideas on what could be at play here and how we could fix it, please let me know.

Adapted example code from the thread
import keras
import bayesflow as bf
import numpy as np

D = 12  # number of variance parameters (inference variables)
N = 10  # set size (observations per simulated dataset)
# large summary dim to reduce its effect
summary_dim = 200
rng = np.random.default_rng(2025)

def prior():
    variance = rng.uniform(size=D)
    return {"variance": variance}


def likelihood(variance):
    y = rng.normal(loc=np.zeros_like(variance), scale=np.sqrt(variance), size=(N, D))
    return {"y": y}

simulator = bf.make_simulator([prior, likelihood])

validation_data = simulator.sample(100)
print("shapes", keras.tree.map_structure(keras.ops.shape, validation_data))

# summary network options
summary_network = bf.networks.DeepSet(
    summary_dim=summary_dim,
    # mlp_widths_invariant_outer=(64, 2),
    # mlp_widths_invariant_inner=(64, 2)
)
# summary_network = bf.networks.SetTransformer(summary_dim=summary_dim)

workflow = bf.BasicWorkflow(
    simulator=simulator,
    summary_network=summary_network,
    summary_variables="y",
    inference_variables=["variance"],
    initial_learning_rate=5e-4,
    standardize="all",
)

workflow.fit_online(num_batches_per_epoch=1000, epochs=10, validation_data=validation_data, batch_size=32);

test_data = simulator.sample(200)

workflow.plot_default_diagnostics(test_data);

Labels: discussion (Discuss a topic or question not necessarily with a clear output in mind.)