Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start a contrib.epidemiology module #2437

Merged
merged 23 commits into from
Apr 24, 2020
Merged

Start a contrib.epidemiology module #2437

merged 23 commits into from
Apr 24, 2020

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Apr 23, 2020

Addresses #2426

This refactors examples/sir_hmc.py into a more reusable framework that separates (1) clear user-facing model specification from (2) intricate effect handler code for inference and prediction. I have preserved the existing examples/sir_hmc.py as a minipyro-like concrete explanation of the new module; it thus serves as architecture documentation. This also makes small changes to sir_hmc.py to keep the two versions aligned (the old concrete and new abstract versions).

Whereas sir_hmc.py contains four models, I was able to reduce user-facing modeling code down to a few methods and only a single duplication: we need a forward .transition_fwd() method and also a .transition_bwd() method; these contain reverse versions of dynamic equations and cannot easily be automatically generated.

After this PR I plan to simplify sir_hmc.py by moving all the fancy DCT and spline stuff into the new module.

Tested

  • Added a simple unit test
  • Ran plotting code locally to verify results are the same as sir_hmc.py

Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks great!

docs/source/contrib.epidemiology.rst Outdated Show resolved Hide resolved

# Sample initial values.
state = self.initialize(params)
state = {i: torch.tensor(float(value)) for i, value in state.items()}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this for exactly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.initialize currently returns deterministic initial states. Because those states are used in torch.nn.functional.pad they must be Python scalars. However in this function those initial states need to start out as tensors.

The state dict in this function acts similarly to the state dict in SMCFilter: it is a framework-managed storage location where users can read and write values. Note the self.transition_fwd(params, state, t) call below updates this state dict in-place.

logp = pyro.distributions.hmm._sequential_logmatmulexp(logp)
logp = logp.reshape(-1).logsumexp(0)
warn_if_nan(logp)
pyro.factor("transition", logp)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happened to the previous -log(4)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have neglected that from this model because (1) it does not affect inference (it does not change gradients), and (2) we will probably be moving to stochastic initial state soon anyway.

@martinjankowiak martinjankowiak merged commit c11b83c into dev Apr 24, 2020
@fritzo fritzo deleted the contrib-epidemiology branch April 28, 2020 01:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants