Skip to content

Conversation

@rileyjmurray
Copy link
Contributor

@rileyjmurray rileyjmurray commented Jan 18, 2024

This PR introduces TorchForwardSimulator, a forward simulator (for computing circuit outcome probabilities) based on PyTorch. It uses automatic differentiation to compute the Jacobian of the map from model parameters to circuit outcome probabilities. In the future we could extend it to do computations on a system's GPU, or to use PyTorch-based optimization algorithms instead of pyGSTi's custom algorithms for MLE.

Approach

My approach required creating a new ModelMember subclass called Torchable. This subclass adds two required functions, called stateless_data and torch_base. Their meanings are given below:

def stateless_data(self) -> Tuple:
"""
Return this ModelMember's data that is considered constant for purposes of model fitting.
Note: the word "stateless" here is used in the sense of object-oriented programming.
"""
raise NotImplementedError()
@staticmethod
def torch_base(sd : Tuple, t_param : Tensor) -> Tensor:
"""
Suppose "obj" is an instance of some Torchable subclass. If we compute
vec = obj.to_vector()
t_param = torch.from_numpy(vec)
sd = obj.stateless_data()
t = type(obj).torch_base(sd, t_param)
then t will be a PyTorch Tensor that represents "obj" in a canonical numerical way.
The meaning of "canonical" is implementation dependent. If type(obj) implements
the ``.base`` attribute, then a reasonable implementation will probably satisfy
np.allclose(obj.base, t.numpy()).
"""
raise NotImplementedError()

In principle, TorchForwardSimulator can handle all models for which constituent parameterized ModelMembers are Torchable. So far I've only extended TPState, FullTPOp, and TPPOVM to be Torchable; these are the classes used in "full TP" GST.

The Python file that contains TorchForwardSimulator also defines two helper classes: StatelessCircuit and StatelessModel. I think it's fine to keep these classes as purely internal implementation-specific constructs for now. Depending on future performance optimizations of TorchForwardSimulator we might want to put them elsewhere in pyGSTi.

What should come after this PR

We should compare performance of TorchForwardSimulator to MapForwardSimulator on problems of interest. There's a chance that the former isn't faster than the latter with the current implementation. If that's the case then I should look at possible performance optimizations specifically inside TorchForwardSimulator.

We should add implementations of stateless_data and torch_base to GST models beyond "Full TP" (in particular I'd like to try CPTP).

Incidental changes

My implementation originally interacted with the following evotype classes

    <class 'pygsti.evotypes.densitymx[_slow].statereps.StateRepDense'>
    <class 'pygsti.evotypes.densitymx[_slow].opreps.OpRepDenseSuperop'>
    <class 'pygsti.evotypes.densitymx[_slow].effectreps.EffectRepConjugatedState'>

When I write [_slow] in the class names above you can put the empty string or just _slow, depending on the default evotype specified in evotypes.py.

To my surprise, I found that interacting with evotypes was neither necessary nor sufficient for what I wanted to accomplish. So while I did make changes in evotypes/densitymx_slow/ to remove unnecessary class inheritances and to add documentation, those changes were only to make life a little easier for future pyGSTi contributors.

…ovms, and gates (as they appear in TorchOpModel._compute_circuit_outcome_probabilities)
…lso SimpleMapForwardSimulator ...) that the dict returned by circuit.expand_instruments_and_separate_povm(...) has at most one element.
…delmembers (needed to construct differentiable torch tensors). Have a new torch_base property of TPState objects. Need such a property for FullTPOp objects. Unclear how to implement for povms, since right now we`re bypassing the POVM abstraction and going directly into the effects abstraction of the circuit.
…n, rather than only through ConjugatedStatePOVMEffect objects associated with a SeparatePOVMCircuit
…vn`t used it to speed up derivative computations yet.
…r before converting to a numpy array and writing to array_to_fill in TorchForwardSimulator._bulk_fill_probs_block.
@rileyjmurray
Copy link
Contributor Author

@coreyostrove and @sserita this is ready for review.

Copy link
Contributor

@coreyostrove coreyostrove left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A true tour de force, @rileyjmurray!

I have left some comments and questions for you above.

One additional piece of feedback is that it would be great to add in some additional unit tests for the new Torchable methods, as well as those for the StatelessModel class and the TorchForwardSimulator. These are covered indirectly by the integration test, but including some correctness checks for the individual function calls would go a long way. Additionally, adding an integration test that confirms the various ForwardSimulator classes give approximately the same probability distributions (within some reasonable tolerance) for a few select circuits would be valuable. (All of that said, I am viewing this functionality as still in beta testing, so I likely won't gate keep based on adding these so long as we commit to doing so eventually).

@rileyjmurray
Copy link
Contributor Author

@coreyostrove I still plan on adding proper unit tests (in contrast to the existing tests, which are really integration tests).

Copy link
Contributor

@sserita sserita left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think everything looks good so far. I'd prefer that we roll back the change removing inheritance from basereps, but I'm open to being convinced. Otherwise my other comments are more questions than explicit action items.
I'm withholding explicit approval until 0.9.13 goes out just as a reminder that we don't want this to merge in before then, and it sounds like you are still working on getting in a few tests as well.

Oh, I also remembered: Tests are failing because of the ComplementPOVM thing. This is not occurring on develop - was this happening now because of one of your TPPOVM changes? I know we had a long-term plan to fix that, but in the short term how can we get this passing again?

@coreyostrove
Copy link
Contributor

I'm withholding explicit approval until 0.9.13 goes out just as a reminder that we don't want this to merge in before then, and it sounds like you are still working on getting in a few tests as well.

Minor clarification, did you mean 0.9.12.3?

@sserita
Copy link
Contributor

sserita commented May 31, 2024

Yes I absolutely meant 0.9.12.3. Speaking of which, Tim just approved that so it will be going live as soon as all the tests pass and I get the release drafted up.

So just let me know when the tests are good on this branch and I'll give my approval also.

@rileyjmurray
Copy link
Contributor Author

@sserita, @coreyostrove all tests pass, including new tests that I wrote in test_forwardsim.py. Merge when ready!

Copy link
Contributor

@sserita sserita left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all this great work @rileyjmurray! This looks good to me, will merge as soon as beta is clear/0.9.12.3 is out.

@sserita sserita merged commit 14f2859 into develop Jun 11, 2024
@sserita sserita deleted the torch-fwdsim branch January 20, 2025 23:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants