Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
fbe23e9
resolve deprecation warning
rileyjmurray Jan 17, 2024
985404f
tiny bugfix
rileyjmurray Jan 17, 2024
8f36247
starting point for building the TorchForwardSimulator class
rileyjmurray Jan 17, 2024
842c0f7
notes
rileyjmurray Jan 18, 2024
1b698e6
infrastructure
rileyjmurray Jan 19, 2024
cbc15b5
I understand how I am stuck and will get help
rileyjmurray Jan 24, 2024
b3ac3da
change list comprehension into for-loop in order to simplify setting …
rileyjmurray Jan 25, 2024
73363d1
leave comments describing object inheritance structures for states, p…
rileyjmurray Jan 25, 2024
9983d1b
comments indicating class types of povm-related objects
rileyjmurray Jan 25, 2024
bd345f6
improve readability
rileyjmurray Jan 26, 2024
2e76f32
remove unnecessary dependence of certain Evotypes on trivial Cython b…
rileyjmurray Jan 26, 2024
bd82b41
left out of last commit
rileyjmurray Jan 26, 2024
e158c21
comments explaining that densitymx_slow is really "superket_slow"
rileyjmurray Jan 26, 2024
c6b4d8f
left out of last commit
rileyjmurray Jan 26, 2024
ae73090
remove commented-out functions which I now clearly understand we do n…
rileyjmurray Jan 26, 2024
ffa7ea0
remove abstraction layers in TorchForwardSimulator
rileyjmurray Jan 26, 2024
d787025
remove more abstractions
rileyjmurray Jan 27, 2024
b510b2e
remove references to new TorchLayerRules class and discussion surroun…
rileyjmurray Jan 27, 2024
6fc59dd
make an apparent limitation of TorchForwardSimulator (and I suppose a…
rileyjmurray Jan 27, 2024
6aac2af
remove unused function
rileyjmurray Jan 27, 2024
107b26b
explicitly override the function that iterates over circuits and call…
rileyjmurray Jan 27, 2024
c1fcfc2
get array representations of all quantities as prep work before compu…
rileyjmurray Jan 31, 2024
761496c
use torch to compute circuit probabilities (infrastructure not in pla…
rileyjmurray Feb 1, 2024
abdfdc7
progress toward bypassing explicit calls to _rep fields of various mo…
rileyjmurray Feb 1, 2024
9b56b2a
more progress on modelmember.torch_base(...) pattern
rileyjmurray Feb 1, 2024
0c9b103
demonstrate how we can access povm data through the TPPOVM abstractio…
rileyjmurray Feb 1, 2024
243b757
write basic TPPOVM.torch_base function. Need to modify that function …
rileyjmurray Feb 1, 2024
0bea829
forward simulation codepath that computes gradients seems to work. Ha…
rileyjmurray Feb 1, 2024
b88643a
can build the entire vector of outcome probabilities as a torch Tenso…
rileyjmurray Feb 1, 2024
c1eacb3
make a function that lets us access the torch representation of compu…
rileyjmurray Feb 2, 2024
3ef9502
simplified torch_cache
rileyjmurray Feb 2, 2024
7073544
step toward what we need for torch jacfwd function
rileyjmurray Feb 2, 2024
aa5c4e7
progress toward functional evaluation in TPPOVM.torch_base. Need to a…
rileyjmurray Feb 2, 2024
0bc3736
add a static_torch_base function
rileyjmurray Feb 2, 2024
6658c47
progress toward statelessness
rileyjmurray Feb 2, 2024
852d8a6
more functional
rileyjmurray Feb 2, 2024
b6bc0f0
created (and put to work) a new StatelessModel helper class
rileyjmurray Feb 2, 2024
9855144
I can successfully call jacfwd and get reasonable output. Next step i…
rileyjmurray Feb 2, 2024
f85716b
IT IS ALIVE
rileyjmurray Feb 2, 2024
14f1af4
note some opportunities for improved efficiency
rileyjmurray Feb 2, 2024
2c6be95
simplified StatelessModel and StatelessCircuit
rileyjmurray Feb 3, 2024
23207f7
remove unnecessary comments
rileyjmurray Feb 3, 2024
3a04a31
clean up TorchForwardSimulator
rileyjmurray Feb 3, 2024
6c2e5f3
revert change that helped with debugging once-upon-a-time, but wasn`t…
rileyjmurray Feb 3, 2024
eb79162
Have meaningful comments for classes in evotypes/densitymx_slow/
rileyjmurray Feb 3, 2024
3461335
improve comments for classes in evotypes/densitymx_slow/
rileyjmurray Feb 3, 2024
1cc944c
remove unused function
rileyjmurray Feb 3, 2024
0e2f051
undo change
rileyjmurray Feb 3, 2024
cfa9232
removed unused file
rileyjmurray Feb 3, 2024
cf05d9a
documentation
rileyjmurray Feb 6, 2024
f312b92
remove comment logged as GitHub Issue #397
rileyjmurray Feb 6, 2024
a55efde
unify the API for torch_base and getting necessary ModelMember metadata
rileyjmurray Feb 6, 2024
a8f6145
remove old comments and unused imports. Style tweaks.
rileyjmurray Feb 6, 2024
e72dbad
formally declare the stateless_data and torch_base functions in the M…
rileyjmurray Feb 6, 2024
d2c8d38
reenable commented-out tests in test_forwardsim.py
rileyjmurray Feb 7, 2024
2435a50
gracefully handle when pytorch is not installed
rileyjmurray Feb 7, 2024
2e4c3cf
stash
rileyjmurray Feb 15, 2024
a3ffa68
better workaround for circular imports in type annotations
rileyjmurray May 6, 2024
f5383b9
Create Torchable subclass of ModelMember
rileyjmurray May 7, 2024
ac2e8e7
remove static constant from TorchForwardSimulator class
rileyjmurray May 7, 2024
5a1be5d
docstring changes
rileyjmurray May 7, 2024
1ec6909
docstring changes
rileyjmurray May 7, 2024
957192a
clean up TPState constructor. Add documentation for TPPOVM. Change im…
rileyjmurray May 22, 2024
07537f3
fix handling lack of pytorch
rileyjmurray May 22, 2024
0e28075
add torch to testing requirements in setup.py
rileyjmurray May 30, 2024
b69c9a0
refactor type annotations and definition of TORCH_ENABLED constant.
rileyjmurray May 30, 2024
d94cdce
lots of variable renaming for easier interpretability. Change free_pa…
rileyjmurray May 31, 2024
6116cd9
renaming
rileyjmurray May 31, 2024
1430b51
readability and formatting
rileyjmurray May 31, 2024
0dbd3cb
more refactoring for easier interpretability
rileyjmurray May 31, 2024
17ba1bb
various tweaks to helper classes in torchfwdsim.py. Documentation in …
rileyjmurray May 31, 2024
2b3f68d
actually switch betweenn jacfwd and jacrev depending on the value of …
rileyjmurray May 31, 2024
ce5f02a
leave TODO note
rileyjmurray May 31, 2024
0eb2a07
revert change that complicated constructor of StatelessCircuit (see h…
rileyjmurray May 31, 2024
863211c
remove outcome_probs function from StatelessCircuit
rileyjmurray May 31, 2024
114a639
test for forward simulators thats conceptually more like a unit test …
rileyjmurray Jun 4, 2024
cff841d
cludge fix for ComplementPOVMEffect.to_vector() error
rileyjmurray Jun 4, 2024
f1c7ec5
change the type of exception raised by ComplementPOVMEffect.to_vector()
rileyjmurray Jun 4, 2024
7f4a45f
change exception type
rileyjmurray Jun 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pygsti/evotypes/densitymx_slow/effectreps.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
import numpy as _np

# import functools as _functools
from .. import basereps as _basereps
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from ...tools import matrixtools as _mt


class EffectRep(_basereps.EffectRep):
class EffectRep:
"""Any representation of an "effect" in the sense of a POVM."""

def __init__(self, state_space):
self.state_space = _StateSpace.cast(state_space)

Expand All @@ -27,6 +28,10 @@ def probability(self, state):


class EffectRepConjugatedState(EffectRep):
"""
A real superket representation of an "effect" in the sense of a POVM.
Internally uses a StateRepDense object to hold the real superket.
"""

def __init__(self, state_rep):
self.state_rep = state_rep
Expand Down
17 changes: 15 additions & 2 deletions pygsti/evotypes/densitymx_slow/opreps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from scipy.sparse.linalg import LinearOperator

from .statereps import StateRepDense as _StateRepDense
from .. import basereps as _basereps
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from ...tools import basistools as _bt
from ...tools import internalgates as _itgs
Expand All @@ -26,7 +25,11 @@
from ...tools import optools as _ot


class OpRep(_basereps.OpRep):
class OpRep:
"""
A real superoperator on Hilbert-Schmidt space.
"""

def __init__(self, state_space):
self.state_space = state_space

Expand All @@ -41,6 +44,10 @@ def adjoint_acton(self, state):
raise NotImplementedError()

def aslinearoperator(self):
"""
Return a SciPy LinearOperator that accepts superket representations of vectors
in Hilbert-Schmidt space and returns a vector of that same representation.
"""
def mv(v):
if v.ndim == 2 and v.shape[1] == 1: v = v[:, 0]
in_state = _StateRepDense(_np.ascontiguousarray(v, 'd'), self.state_space, None)
Expand All @@ -54,6 +61,12 @@ def rmv(v):


class OpRepDenseSuperop(OpRep):
"""
A real superoperator on Hilbert-Schmidt space.
The operator's action (and adjoint action) work with Hermitian matrices
stored as *vectors* in their real superket representations.
"""

def __init__(self, mx, basis, state_space):
state_space = _StateSpace.cast(state_space)
if mx is None:
Expand Down
13 changes: 10 additions & 3 deletions pygsti/evotypes/densitymx_slow/statereps.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import numpy as _np

from .. import basereps as _basereps
from pygsti.baseobjs.statespace import StateSpace as _StateSpace
from ...tools import basistools as _bt
from ...tools import optools as _ot
Expand All @@ -25,13 +24,17 @@
_fastcalc = None


class StateRep(_basereps.StateRep):
class StateRep:
"""A real superket representation of an element in Hilbert-Schmidt space."""

def __init__(self, data, state_space):
#vec = _np.asarray(vec, dtype='d')
assert(data.dtype == _np.dtype('d'))
self.data = _np.require(data.copy(), requirements=['OWNDATA', 'C_CONTIGUOUS'])
self.state_space = _StateSpace.cast(state_space)
assert(len(self.data) == self.state_space.dim)
ds0 = self.data.shape[0]
assert(ds0 == self.state_space.dim)
assert(ds0 == self.data.size)

def __reduce__(self):
return (StateRep, (self.data, self.state_space), (self.data.flags.writeable,))
Expand Down Expand Up @@ -62,6 +65,10 @@ def __str__(self):


class StateRepDense(StateRep):
"""
An almost-trivial wrapper around StateRep.
Implements the "base" property and defines a trivial "base_has_changed" function.
"""

def __init__(self, data, state_space, basis):
#ignore basis for now (self.basis = basis in future?)
Expand Down
1 change: 1 addition & 0 deletions pygsti/forwardsims/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .forwardsim import ForwardSimulator
from .mapforwardsim import SimpleMapForwardSimulator, MapForwardSimulator
from .torchfwdsim import TorchForwardSimulator
from .matrixforwardsim import SimpleMatrixForwardSimulator, MatrixForwardSimulator
from .termforwardsim import TermForwardSimulator
from .weakforwardsim import WeakForwardSimulator
2 changes: 1 addition & 1 deletion pygsti/forwardsims/forwardsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def create_layout(self, circuits, dataset=None, resource_alloc=None,
if 'epp' in array_types:
derivative_dimensions = (self.model.num_params, self.model.num_params)
elif 'ep' in array_types:
derivative_dimensions = (self.model.num_params)
derivative_dimensions = (self.model.num_params,)
else:
derivative_dimensions = tuple()
return _CircuitOutcomeProbabilityArrayLayout.create_from(circuits, self.model, dataset, derivative_dimensions,
Expand Down
264 changes: 264 additions & 0 deletions pygsti/forwardsims/torchfwdsim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
"""
Defines a ForwardSimulator class called "TorchForwardSimulator" that can leverage the automatic
differentation features of PyTorch.

This file also defines two helper classes: StatelessCircuit and StatelessModel.

See also: pyGSTi/modelmembers/torchable.py.
"""
#***************************************************************************************************
# Copyright 2024, National Technology & Engineering Solutions of Sandia, LLC (NTESS).
# Under the terms of Contract DE-NA0003525 with NTESS, the U.S. Government retains certain rights
# in this software.
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0 or in the LICENSE file in the root pyGSTi directory.
#***************************************************************************************************


from __future__ import annotations
from typing import Tuple, Optional, Dict, TYPE_CHECKING
if TYPE_CHECKING:
from pygsti.baseobjs.label import Label
from pygsti.models.explicitmodel import ExplicitOpModel
from pygsti.circuits.circuit import SeparatePOVMCircuit
from pygsti.layouts.copalayout import CircuitOutcomeProbabilityArrayLayout
import torch

import warnings as warnings
from pygsti.modelmembers.torchable import Torchable
from pygsti.forwardsims.forwardsim import ForwardSimulator

try:
import torch
TORCH_ENABLED = True
except ImportError:
TORCH_ENABLED = False
pass


class StatelessCircuit:
"""
Helper data structure for specifying a quantum circuit (consisting of prep,
applying a sequence of gates, and applying a POVM to the output of the last gate).
"""

def __init__(self, spc: SeparatePOVMCircuit):
self.prep_label = spc.circuit_without_povm[0]
self.op_labels = spc.circuit_without_povm[1:]
self.povm_label = spc.povm_label
self.outcome_probs_dim = len(spc.effect_labels)
# ^ This definition of outcome_probs_dim will need to be changed if/when
# we extend any Instrument class to be Torchable.
return


class StatelessModel:
"""
A container for the information in an ExplicitOpModel that's "stateless" in the sense of
object-oriented programming:

* A list of StatelessCircuits
* Metadata for parameterized ModelMembers

StatelessModels have instance functions to facilitate computation of (differentable!)
circuit outcome probabilities.

Design notes
------------
Much of this functionality could be packed into the TorchForwardSimulator class.
Keeping it separate from TorchForwardSimulator helps clarify that it uses none of
the sophiciated machinery in TorchForwardSimulator's base class.
"""

def __init__(self, model: ExplicitOpModel, layout: CircuitOutcomeProbabilityArrayLayout):
circuits = []
self.outcome_probs_dim = 0
for _, circuit, outcomes in layout.iter_unique_circuits():
expanded_circuits = circuit.expand_instruments_and_separate_povm(model, outcomes)
if len(expanded_circuits) > 1:
raise NotImplementedError("I don't know what to do with this.")
spc = next(iter(expanded_circuits))
c = StatelessCircuit(spc)
circuits.append(c)
self.outcome_probs_dim += c.outcome_probs_dim
self.circuits = circuits

# We need to verify assumptions on what layout.iter_unique_circuits() returns.
# Looking at the implementation of that function, the assumptions can be
# framed in terms of the "layout._element_indicies" dict.
eind = layout._element_indices
assert isinstance(eind, dict)
items = iter(eind.items())
k_prev, v_prev = next(items)
assert k_prev == 0
assert v_prev.start == 0
for k, v in items:
assert k == k_prev + 1
assert v.start == v_prev.stop
k_prev = k
v_prev = v
assert self.outcome_probs_dim == v_prev.stop

self.param_metadata = []
for lbl, obj in model._iter_parameterized_objs():
assert isinstance(obj, Torchable), f"{type(obj)} does not subclass {Torchable}."
param_type = type(obj)
param_data = (lbl, param_type) + (obj.stateless_data(),)
self.param_metadata.append(param_data)
self.params_dim = None
# ^ That's set in get_free_params.

self.default_to_reverse_ad = None
# ^ That'll be set to a boolean the next time that get_free_params is called.
return

def get_free_params(self, model: ExplicitOpModel) -> Tuple[torch.Tensor]:
"""
Return a tuple of Tensors that encode the states of the provided model's ModelMembers
(where "state" in meant the sense of object-oriented programming).

We compare the labels of the input model's ModelMembers to those of the model provided
to StatelessModel.__init__(...). We raise an error if an inconsistency is detected.
"""
free_params = []
prev_idx = 0
for i, (lbl, obj) in enumerate(model._iter_parameterized_objs()):
gpind = obj.gpindices_as_array()
vec = obj.to_vector()
vec_size = vec.size
vec = torch.from_numpy(vec)
assert gpind[0] == prev_idx and gpind[-1] == prev_idx + vec_size - 1
# ^ We should have gpind = (prev_idx, prev_idx + 1, ..., prev_idx + vec.size - 1).
# That assert checks a cheap necessary condition that this holds.
prev_idx += vec_size
if self.param_metadata[i][0] != lbl:
message = """
The model passed to get_free_params has a qualitatively different structure from
the model used to construct this StatelessModel. Specifically, the two models have
qualitative differences in the output of "model._iter_parameterized_objs()".

The presence of this structral difference essentially gaurantees that a subsequent
call to get_torch_bases would silently fail, so we're forced to raise an error here.
"""
raise ValueError(message)
free_params.append(vec)
self.params_dim = prev_idx
self.default_to_reverse_ad = self.outcome_probs_dim < self.params_dim
return tuple(free_params)

def get_torch_bases(self, free_params: Tuple[torch.Tensor]) -> Dict[Label, torch.Tensor]:
"""
Take data of the kind produced by get_free_params and format it in the way required by
circuit_probs_from_torch_bases.

Note
----
If you want to use the returned dict to build a PyTorch Tensor that supports the
.backward() method, then you need to make sure that fp.requires_grad is True for all
fp in free_params. This can be done by calling fp._requires_grad(True) before calling
this function.
"""
assert len(free_params) == len(self.param_metadata)
# ^ A sanity check that we're being called with the correct number of arguments.
torch_bases = dict()
for i, val in enumerate(free_params):

label, type_handle, stateless_data = self.param_metadata[i]
param_t = type_handle.torch_base(stateless_data, val)
torch_bases[label] = param_t

return torch_bases

def circuit_probs_from_torch_bases(self, torch_bases: Dict[Label, torch.Tensor]) -> torch.Tensor:
"""
Compute the circuit outcome probabilities that result when all of this StatelessModel's
StatelessCircuits are run with data in torch_bases.

Return the results as a single (vectorized) torch Tensor.
"""
probs = []
for c in self.circuits:
superket = torch_bases[c.prep_label]
superops = [torch_bases[ol] for ol in c.op_labels]
povm_mat = torch_bases[c.povm_label]
for superop in superops:
superket = superop @ superket
circuit_probs = povm_mat @ superket
probs.append(circuit_probs)
probs = torch.concat(probs)
return probs

def circuit_probs_from_free_params(self, *free_params: Tuple[torch.Tensor], enable_backward=False) -> torch.Tensor:
"""
This is the basic function we expose to pytorch for automatic differentiation. It returns the circuit
outcome probabilities resulting when the states of ModelMembers associated with this StatelessModel
are set based on free_params.

If you want to call PyTorch's .backward() on the returned Tensor (or a function of that Tensor), then
you should set enable_backward=True. Keep the default value of enable_backward=False in all other
situations, including when using PyTorch's jacrev function.
"""
if enable_backward:
for fp in free_params:
fp._requires_grad(True)
torch_bases = self.get_torch_bases(free_params)
probs = self.circuit_probs_from_torch_bases(torch_bases)
return probs


class TorchForwardSimulator(ForwardSimulator):
"""
A forward simulator that leverages automatic differentiation in PyTorch.
"""

ENABLED = TORCH_ENABLED

def __init__(self, model : Optional[ExplicitOpModel] = None):
if not self.ENABLED:
raise RuntimeError('PyTorch could not be imported.')
self.model = model
super(ForwardSimulator, self).__init__(model)

def _bulk_fill_probs(self, array_to_fill, layout, split_model = None) -> None:
if split_model is None:
slm = StatelessModel(self.model, layout)
free_params = slm.get_free_params(self.model)
torch_bases = slm.get_torch_bases(free_params)
else:
slm, torch_bases = split_model

probs = slm.circuit_probs_from_torch_bases(torch_bases)
array_to_fill[:slm.outcome_probs_dim] = probs.cpu().detach().numpy().ravel()
return

def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill) -> None:
slm = StatelessModel(self.model, layout)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a fresh StatelessModel for every jacobian call sounds expensive. For _bulk_fill_probs you had an optional stripped_abstractions argument that allowed for reuse of a previously generated one, does that not work here as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tiny note up front: I've renamed the variable that you're referencing asstripped_abstractions. It's now called split_model. Now to answer your question ...

For _bulk_fill_probs (no "d") I was able to add an optional argument since there's a codepath in _bulk_fill_dprobs (with "d") that requires calling _bulk_fill_probs. In order to make use of an extra argument here we'd need to override the functions in the ForwardSimulator base class that call this function. I've added a TODO for us to explore this later.

# ^ TODO: figure out how to safely recycle StatelessModel objects from one
# call to another. The current implementation is wasteful if we need to
# compute many jacobians without structural changes to layout or self.model.
free_params = slm.get_free_params(self.model)

if pr_array_to_fill is not None:
torch_bases = slm.get_torch_bases(free_params)
splitm = (slm, torch_bases)
self._bulk_fill_probs(pr_array_to_fill, layout, splitm)

argnums = tuple(range(len(slm.param_metadata)))
if slm.default_to_reverse_ad:
# Then slm.circuit_probs_from_free_params will automatically construct the
# torch_base dict to support reverse-mode AD.
J_func = torch.func.jacrev(slm.circuit_probs_from_free_params, argnums=argnums)
else:
# Then slm.circuit_probs_from_free_params will automatically skip the extra
# steps needed for torch_base to support reverse-mode AD.
J_func = torch.func.jacfwd(slm.circuit_probs_from_free_params, argnums=argnums)
# ^ Note that this _bulk_fill_dprobs function doesn't accept parameters that
# could be used to override the default behavior of the StatelessModel. If we
# have a need to override the default in the future then we'd need to override
# the ForwardSimulator function(s) that call self._bulk_fill_dprobs(...).

J_val = J_func(*free_params)
J_val = torch.column_stack(J_val)
array_to_fill[:] = J_val.cpu().detach().numpy()
return
Loading