Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SMC to initialize compartmental models #2452

Merged
merged 5 commits into from
Apr 28, 2020
Merged

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Apr 28, 2020

Addresses #2426

This replaces custom .heuristic() methods with a default generic SMCFilter-based initialization strategy, greatly simplifying the process of creating new compartmental models. Users can still override this default .heuristic() if needed.

🎉 hooray for composable inference 🎉

Tested

  • refactoring covered by existing tests
  • ran python sir.py -p 10000 -d 60 -f 30 --verbose --plot and confirmed good posteriors and mixing:
         mean       std    median      5.0%     95.0%     n_eff     r_hat
 R0      1.48      0.08      1.47      1.36      1.61     16.89      1.09
rho      0.48      0.03      0.49      0.44      0.53     29.28      1.14

image


def step(self, state):
with poutine.block(), poutine.condition(data=state):
params = self.model.global_model()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed exactly?

Copy link
Member Author

@fritzo fritzo Apr 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recall that SMCFilter tracks all state in a state dict and periodically resamples all that state. In this use case the state dict contains values of all sample sites that we will eventually replay to initialize MCMC. In particular it contains sample sites for global variables. Now since resampling changes those sites, and since global_model() not only draws samples but performs computation on those samples, we need to re-perform that computation after each resampling. Hence we block (to avoid this being traced again), and condition on the SMCFilter-managed state to get the latest resampled params each step.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants