Incorrect AssertionError message in NPE.train() #1676

@abelaba

🐛 Bug Description

When using NPE and calling train() a second time with both force_first_round_loss and resume_training set to False, the assertion at https://github.com/sbi-dev/sbi/blob/8762b0e3fdea8804a4327e286342891697cf7f46/sbi/inference/trainers/npe/npe_base.py#L619C10-L631C14 is raised.

The assertion message states that append_simulations(theta, x) was called again after the network had been trained, even though append_simulations was not called again after the first training run. The message therefore misdescribes what triggered the assertion.
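The mismatch can be illustrated with a small standalone model. This is a hypothetical sketch, not sbi's actual implementation: `ToyTrainer` and `reproduce` are invented names, and the sketch assumes that the guard in train() checks only whether a network has already been trained and whether either flag is True, without tracking whether append_simulations was called after the last training run. Under that assumption, the assertion fires with a message about newly appended simulations even though none were added.

```python
class ToyTrainer:
    """Hypothetical stand-in for an NPE-style trainer; not sbi's real code."""

    def __init__(self) -> None:
        self.net_trained = False
        # Tracked only to show that the assertion message can be wrong.
        self.appended_since_training = False

    def append_simulations(self) -> None:
        self.appended_since_training = True

    def train(self, force_first_round_loss: bool = False,
              resume_training: bool = False) -> None:
        if self.net_trained:
            # ASSUMPTION: the guard depends only on "already trained" and the
            # two flags, not on appended_since_training, so its message about
            # new simulations can be false.
            assert force_first_round_loss or resume_training, (
                "append_simulations(theta, x) was called after training"
            )
        self.net_trained = True
        self.appended_since_training = False


def reproduce() -> tuple[bool, str]:
    """Mirror the issue's steps; return (new data appended?, assertion message)."""
    trainer = ToyTrainer()
    trainer.append_simulations()
    trainer.train()
    try:
        trainer.train(force_first_round_loss=False, resume_training=False)
    except AssertionError as err:
        return trainer.appended_since_training, str(err)
    return trainer.appended_since_training, ""


appended, message = reproduce()
print(appended, message)
# prints: False append_simulations(theta, x) was called after training
```

In this toy model the first return value is False, meaning no simulations were appended before the failing call, yet the message claims otherwise.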

🔄 Steps to Reproduce

  1. Python version: 3.10.16; sbi version: 0.24.0
  2. Minimal code example that triggers the bug:
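  import torch
  from sbi.inference import NPE
  from sbi.utils import BoxUniform

  def simulator(theta):
      # Toy simulator: x = 1 + theta + Gaussian noise with std 0.1
      return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1

  num_dim = 3
  prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
  theta = prior.sample((300,))
  x = simulator(theta)

  inference = NPE(prior=prior)
  inference.append_simulations(theta, x)  # the only call to append_simulations

  inference.train()  # first training run succeeds
  # Second run raises an AssertionError whose message says append_simulations
  # was called again after training, which did not happen.
  inference.train(force_first_round_loss=False, resume_training=False)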

Labels: bug
