Closed
Labels: bug (Something isn't working)
Description
🐛 Bug Description
When using NPE and calling `train()` a second time with both `force_first_round_loss` and `resume_training` set to `False`, the assertion at these lines is raised: https://github.com/sbi-dev/sbi/blob/8762b0e3fdea8804a4327e286342891697cf7f46/sbi/inference/trainers/npe/npe_base.py#L619C10-L631C14
The assertion message says that `append_simulations(theta, x)` was called after the network had been trained, even though `append_simulations` was not called again after the first `train()` call.
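The following self-contained sketch models the suspected cause. It assumes that the check at the linked lines only tests whether a network has already been trained and whether the most recent data round is 0. If so, it cannot tell "re-training on the same data" apart from "appending new prior simulations after training". `ToyTrainer`, `track_appends`, and `second_train_raises` are hypothetical names, not sbi code. The `track_appends=True` branch shows one possible fix (remembering whether data was appended since the last `train()`), not the maintainers' fix.

```python
class ToyTrainer:
    """Hypothetical model of the check; NOT sbi's actual trainer class."""

    def __init__(self, track_appends):
        # False: assumed current check. True: hypothetical fix.
        self.track_appends = track_appends
        self._data_round_index = []  # round index of each appended batch
        self._neural_net = None  # stands in for the trained density estimator
        self._appended_since_train = False

    def append_simulations(self, theta, x):
        self._data_round_index.append(0)  # no proposal given -> round 0
        self._appended_since_train = True
        return self

    def train(self, force_first_round_loss=False, resume_training=False):
        latest_round = max(self._data_round_index)
        already_trained = self._neural_net is not None
        # Assumed current behavior: every call after the first counts as "new data".
        new_data = self._appended_since_train or not self.track_appends
        if latest_round == 0 and already_trained and new_data:
            if not (force_first_round_loss or resume_training):
                raise AssertionError(
                    "append_simulations(theta, x) was called after training ..."
                )
        self._neural_net = object()  # pretend training happened
        self._appended_since_train = False
        return self._neural_net


def second_train_raises(track_appends):
    """Append once, train twice (as in the report); return True if it raises."""
    trainer = ToyTrainer(track_appends)
    trainer.append_simulations(theta=None, x=None)
    trainer.train()
    try:
        trainer.train(force_first_round_loss=False, resume_training=False)
    except AssertionError:
        return True
    return False


print(second_train_raises(track_appends=False))  # mirrors the reported error
print(second_train_raises(track_appends=True))  # plain re-training allowed
```

With `track_appends=True`, the check still fires in the case the message describes (new prior simulations appended after training, with neither flag set).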
🔄 Steps to Reproduce
- Python version: 3.10.16; sbi version: 0.24.0
- Minimal code example that triggers the bug:
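```python
import torch
from sbi.inference import NPE
from sbi.utils import BoxUniform

def simulator(theta):
    # Toy simulator: x = 1 + theta + small Gaussian noise.
    return 1.0 + theta + torch.randn(theta.shape, device=theta.device) * 0.1

num_dim = 3
prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))
theta = prior.sample((300,))
x = simulator(theta)

inference = NPE(prior=prior)
inference.append_simulations(theta, x)
inference.train()  # first call: trains normally
# Second call, with no new simulations appended: raises the AssertionError.
inference.train(force_first_round_loss=False, resume_training=False)
```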