-
Couldn't load subscription status.
- Fork 58
PyTorch-backed forward simulation #390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…breakpoints in debugging
…ovms, and gates (as they appear in TorchOpModel._compute_circuit_outcome_probabilities)
…ding attempts to use that class
…lso SimpleMapForwardSimulator ...) that the dict returned by circuit.expand_instruments_and_separate_povm(...) has at most one element.
…s _compute_circuit_outcome_probabilities
…ting any circuit probabilities
…ce for differentiation yet)
…delmembers (needed to construct differentiable torch tensors). Have a new torch_base property of TPState objects. Need such a property for FullTPOp objects. Unclear how to implement for povms, since right now we`re bypassing the POVM abstraction and going directly into the effects abstraction of the circuit.
…n, rather than only through ConjugatedStatePOVMEffect objects associated with a SeparatePOVMCircuit
…so it allows require_grad=True.
…vn`t used it to speed up derivative computations yet.
…r before converting to a numpy array and writing to array_to_fill in TorchForwardSimulator._bulk_fill_probs_block.
|
@coreyostrove and @sserita this is ready for review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A true tour de force, @rileyjmurray!
I have left some comments and questions for you above.
One additional piece of feedback is that it would be great to add in some additional unit tests for the new Torchable methods, as well as those for the StatelessModel class and the TorchForwardSimulator. These are covered indirectly by the integration test, but including some correctness checks for the individual function calls would go a long way. Additionally, adding an integration test that confirms the various ForwardSimulator classes give approximately the same probability distributions (within some reasonable tolerance) for a few select circuits would be valuable. (All of that said, I am viewing this functionality as still in beta testing, so I likely won't gate keep based on adding these so long as we commit to doing so eventually).
…rams from a dict to a tuple. Remove the torch_handle pattern in the Torchable class.
…torchfwdsim.py and torchable.py.
…StatelessModel.default_to_reverse_ad.
|
@coreyostrove I still plan on adding proper unit tests (in contrast to the existing tests, which are really integration tests). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think everything looks good so far. I'd prefer that we roll back the change removing inheritance from basereps, but I'm open to being convinced. Otherwise my other comments are more questions than explicit action items.
I'm withholding explicit approval until 0.9.13 goes out just as a reminder that we don't want this to merge in before then, and it sounds like you are still working on getting in a few tests as well.
Oh, I also remembered: Tests are failing because of the ComplementPOVM thing. This is not occurring on develop - was this happening now because of one of your TPPOVM changes? I know we had a long-term plan to fix that, but in the short term how can we get this passing again?
Minor clarification, did you mean 0.9.12.3? |
|
Yes I absolutely meant 0.9.12.3. Speaking of which, Tim just approved that so it will be going live as soon as all the tests pass and I get the release drafted up. So just let me know when the tests are good on this branch and I'll give my approval also. |
|
@sserita, @coreyostrove all tests pass, including new tests that I wrote in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for all this great work @rileyjmurray! This looks good to me, will merge as soon as beta is clear/0.9.12.3 is out.
This PR introduces
TorchForwardSimulator, a forward simulator (for computing circuit outcome probabilities) based on PyTorch. It uses automatic differentiation to compute the Jacobian of the map from model parameters to circuit outcome probabilities. In the future we could extend it to do computations on a system's GPU, or to use PyTorch-based optimization algorithms instead of pyGSTi's custom algorithms for MLE.Approach
My approach required creating a new ModelMember subclass called
Torchable. This subclass adds two required functions, calledstateless_dataandtorch_base. Their meanings are given below:pyGSTi/pygsti/modelmembers/torchable.py
Lines 18 to 43 in 1ec6909
In principle, TorchForwardSimulator can handle all models for which constituent parameterized ModelMembers are Torchable. So far I've only extended TPState, FullTPOp, and TPPOVM to be Torchable; these are the classes used in "full TP" GST.
The Python file that contains TorchForwardSimulator also defines two helper classes: StatelessCircuit and StatelessModel. I think it's fine to keep these classes as purely internal implementation-specific constructs for now. Depending on future performance optimizations of TorchForwardSimulator we might want to put them elsewhere in pyGSTi.
What should come after this PR
We should compare performance of TorchForwardSimulator to MapForwardSimulator on problems of interest. There's a chance that the former isn't faster than the latter with the current implementation. If that's the case then I should look at possible performance optimizations specifically inside TorchForwardSimulator.
We should add implementations of
stateless_dataandtorch_baseto GST models beyond "Full TP" (in particular I'd like to try CPTP).Incidental changes
My implementation originally interacted with the following evotype classes
When I write
[_slow]in the class names above you can put the empty string or just_slow, depending on the default evotype specified inevotypes.py.To my surprise, I found that interacting with evotypes was neither necessary nor sufficient for what I wanted to accomplish. So while I did make changes in
evotypes/densitymx_slow/to remove unnecessary class inheritances and to add documentation, those changes were only to make life a little easier for future pyGSTi contributors.