2020from pygsti import tools as _tools
2121from pygsti .tools import mpitools as _mpit
2222from pygsti .tools import slicetools as _slct
23- from pygsti .models .gaugegroup import TrivialGaugeGroupElement as _TrivialGaugeGroupElement
23+ from pygsti .models .gaugegroup import (
24+ TrivialGaugeGroupElement as _TrivialGaugeGroupElement ,
25+ GaugeGroupElement as _GaugeGroupElement
26+ )
27+
28+ from typing import Callable , Union , Optional
2429
2530
2631def gaugeopt_to_target (model , target_model , item_weights = None ,
@@ -29,7 +34,7 @@ def gaugeopt_to_target(model, target_model, item_weights=None,
2934 gauge_group = None , method = 'auto' , maxiter = 100000 ,
3035 maxfev = None , tol = 1e-8 , oob_check_interval = 0 ,
3136 convert_model_to = None , return_all = False , comm = None ,
32- verbosity = 0 , check_jac = False ):
37+ verbosity = 0 , check_jac = False , n_leak = 0 ):
3338 """
3439 Optimize the gauge degrees of freedom of a model to that of a target.
3540
@@ -170,7 +175,7 @@ def gaugeopt_to_target(model, target_model, item_weights=None,
170175 objective_fn , jacobian_fn = _create_objective_fn (
171176 model , target_model , item_weights ,
172177 cptp_penalty_factor , spam_penalty_factor ,
173- gates_metric , spam_metric , method , comm , check_jac )
178+ gates_metric , spam_metric , method , comm , check_jac , n_leak )
174179
175180 result = gaugeopt_custom (model , objective_fn , gauge_group , method ,
176181 maxiter , maxfev , tol , oob_check_interval ,
@@ -307,9 +312,6 @@ def _call_jacobian_fn(gauge_group_el_vec):
307312
308313 printer .log ("--- Gauge Optimization (%s method, %s) ---" % (method , str (type (gauge_group ))), 2 )
309314 if method == 'ls' :
310- #minSol = _opt.least_squares(_call_objective_fn, x0, #jac=_call_jacobian_fn,
311- # max_nfev=maxfev, ftol=tol)
312- #solnX = minSol.x
313315 assert (_call_jacobian_fn is not None ), "Cannot use 'ls' method unless jacobian is available"
314316 ralloc = _baseobjs .ResourceAllocation (comm ) # FUTURE: plumb up a resource alloc object?
315317 test_f = _call_objective_fn (x0 )
@@ -354,10 +356,15 @@ def _call_jacobian_fn(gauge_group_el_vec):
354356 return newModel
355357
356358
357- def _create_objective_fn (model , target_model , item_weights = None ,
358- cptp_penalty_factor = 0 , spam_penalty_factor = 0 ,
359+ GGElObjective = Callable [[_GaugeGroupElement ,bool ], Union [float , _np .ndarray ]]
360+
361+ GGElJacobian = Union [None , Callable [[_GaugeGroupElement ], _np .ndarray ]]
362+
363+
364+ def _create_objective_fn (model , target_model , item_weights : Optional [dict [str ,float ]]= None ,
365+ cptp_penalty_factor : float = 0.0 , spam_penalty_factor : float = 0.0 ,
359366 gates_metric = "frobenius" , spam_metric = "frobenius" ,
360- method = None , comm = None , check_jac = False ) :
367+ method = None , comm = None , check_jac = False , n_leak = 0 ) -> tuple [ GGElObjective , GGElJacobian ] :
361368 """
362369 Creates the objective function and jacobian (if available)
363370 for gaugeopt_to_target
@@ -595,17 +602,32 @@ def _mock_objective_fn(v):
595602 # non-least-squares case where objective function returns a single float
596603 # and (currently) there's no analytic jacobian
597604
605+ assert gates_metric != "frobeniustt"
606+ assert spam_metric != "frobeniustt"
607+ # ^ PR #410 removed support for Frobenius transform-target metrics in this codepath.
608+
609+ dim = int (_np .sqrt (mxBasis .dim ))
610+ if n_leak > 0 :
611+ B = _tools .leading_dxd_submatrix_basis_vectors (dim - n_leak , dim , mxBasis )
612+ P = B @ B .T .conj ()
613+ if _np .linalg .norm (P .imag ) > 1e-12 :
614+ msg = f"Attempting to run leakage-aware gauge optimization with basis { mxBasis } \n "
615+ msg += "is resulting an orthogonal projector onto the computational subspace that\n "
616+ msg += "is not real-valued. Try again with a different basis, like 'l2p1' or 'gm'."
617+ raise ValueError (msg )
618+ else :
619+ P = P .real
620+ transform_mx_arg = (P , _tools .matrixtools .IdentityOperator ())
621+ # ^ The semantics of this tuple are defined by the frobeniusdist function
622+ # in the ExplicitOpModelCalc class.
623+ else :
624+ transform_mx_arg = None
625+ # ^ It would be equivalent to set this to a pair of IdentityOperator objects.
626+
598627 def _objective_fn (gauge_group_el , oob_check ):
599628 mdl = _transform_with_oob_check (model , gauge_group_el , oob_check )
600629 ret = 0
601630
602- if gates_metric == "frobeniustt" or spam_metric == "frobeniustt" :
603- full_target_model = target_model .copy ()
604- full_target_model .convert_members_inplace ("full" ) # so we can gauge-transform the target model.
605- transformed_target = _transform_with_oob_check (full_target_model , gauge_group_el .inverse (), oob_check )
606- else :
607- transformed_target = None
608-
609631 if cptp_penalty_factor > 0 :
610632 mdl .basis = mxBasis # set basis for jamiolkowski iso
611633 cpPenaltyVec = _cptp_penalty (mdl , cptp_penalty_factor , mdl .basis )
@@ -616,84 +638,86 @@ def _objective_fn(gauge_group_el, oob_check):
616638 spamPenaltyVec = _spam_penalty (mdl , spam_penalty_factor , mdl .basis )
617639 ret += _np .sum (spamPenaltyVec )
618640
619- if target_model is not None :
620- if gates_metric == "frobenius" :
621- if spam_metric == "frobenius" :
622- ret += mdl .frobeniusdist (target_model , None , item_weights )
623- else :
624- wts = item_weights .copy (); wts ['spam' ] = 0.0
625- for k in wts :
626- if k in mdl .preps or \
627- k in mdl .povms : wts [k ] = 0.0
628- ret += mdl .frobeniusdist (target_model , None , wts )
629-
630- elif gates_metric == "frobeniustt" :
631- if spam_metric == "frobeniustt" :
632- ret += transformed_target .frobeniusdist (model , None , item_weights )
633- else :
634- wts = item_weights .copy (); wts ['spam' ] = 0.0
635- for k in wts :
636- if k in mdl .preps or \
637- k in mdl .povms : wts [k ] = 0.0
638- ret += transformed_target .frobeniusdist (model , None , wts )
639-
640- elif gates_metric == "fidelity" :
641- for opLbl in mdl .operations :
642- wt = item_weights .get (opLbl , opWeight )
643- ret += wt * (1.0 - _tools .entanglement_fidelity (
644- target_model .operations [opLbl ], mdl .operations [opLbl ]))** 2
645-
646- elif gates_metric == "tracedist" :
647- for opLbl in mdl .operations :
648- wt = item_weights .get (opLbl , opWeight )
649- ret += opWeight * _tools .jtracedist (
650- target_model .operations [opLbl ], mdl .operations [opLbl ])
651-
652- else : raise ValueError ("Invalid gates_metric: %s" % gates_metric )
653-
654- if spam_metric == "frobenius" :
655- if gates_metric != "frobenius" : # otherwise handled above to match normalization in frobeniusdist
656- wts = item_weights .copy (); wts ['gates' ] = 0.0
657- for k in wts :
658- if k in mdl .operations or \
659- k in mdl .instruments : wts [k ] = 0.0
660- ret += mdl .frobeniusdist (target_model , None , wts )
661-
662- elif spam_metric == "frobeniustt" :
663- if gates_metric != "frobeniustt" : # otherwise handled above to match normalization in frobeniusdist
664- wts = item_weights .copy (); wts ['gates' ] = 0.0
665- for k in wts :
666- if k in mdl .operations or \
667- k in mdl .instruments : wts [k ] = 0.0
668- ret += transformed_target .frobeniusdist (model , None , wts )
669-
670- elif spam_metric == "fidelity" :
671- for preplabel , prep in mdl .preps .items ():
672- wt = item_weights .get (preplabel , spamWeight )
673- rhoMx1 = _tools .vec_to_stdmx (prep , mxBasis )
674- rhoMx2 = _tools .vec_to_stdmx (
675- target_model .preps [preplabel ], mxBasis )
676- ret += wt * (1.0 - _tools .fidelity (rhoMx1 , rhoMx2 ))** 2
677-
678- for povmlabel , povm in mdl .povms .items ():
679- wt = item_weights .get (povmlabel , spamWeight )
680- ret += wt * (1.0 - _tools .povm_fidelity (
681- mdl , target_model , povmlabel ))** 2
682-
683- elif spam_metric == "tracedist" :
684- for preplabel , prep in mdl .preps .items ():
685- wt = item_weights .get (preplabel , spamWeight )
686- rhoMx1 = _tools .vec_to_stdmx (prep , mxBasis )
687- rhoMx2 = _tools .vec_to_stdmx (
688- target_model .preps [preplabel ], mxBasis )
689- ret += wt * _tools .tracedist (rhoMx1 , rhoMx2 )
690-
691- for povmlabel , povm in mdl .povms .items ():
692- wt = item_weights .get (povmlabel , spamWeight )
693- ret += wt * (1.0 - _tools .povm_jtracedist (
694- mdl , target_model , povmlabel ))** 2
695-
696- else : raise ValueError ("Invalid spam_metric: %s" % spam_metric )
641+ if target_model is None :
642+ return ret
643+
644+ if "frobenius" in gates_metric :
645+ if spam_metric == gates_metric :
646+ val = mdl .frobeniusdist (target_model , transform_mx_arg , item_weights )
647+ else :
648+ wts = item_weights .copy ()
649+ wts ['spam' ] = 0.0
650+ for k in wts :
651+ if k in mdl .preps or k in mdl .povms :
652+ wts [k ] = 0.0
653+ val = mdl .frobeniusdist (target_model , transform_mx_arg , wts , n_leak )
654+ if "squared" in gates_metric :
655+ val = val ** 2
656+ ret += val
657+
658+ elif gates_metric == "fidelity" :
659+ # If n_leak==0, then subspace_entanglement_fidelity is just entanglement_fidelity
660+ for opLbl in mdl .operations :
661+ wt = item_weights .get (opLbl , opWeight )
662+ top = target_model .operations [opLbl ].to_dense ()
663+ mop = mdl .operations [opLbl ].to_dense ()
664+ ret += wt * (1.0 - _tools .subspace_entanglement_fidelity (top , mop , mxBasis , n_leak ))** 2
665+
666+ elif gates_metric == "tracedist" :
667+ # If n_leak==0, then subspace_jtracedist is just jtracedist.
668+ for opLbl in mdl .operations :
669+ wt = item_weights .get (opLbl , opWeight )
670+ top = target_model .operations [opLbl ].to_dense ()
671+ mop = mdl .operations [opLbl ].to_dense ()
672+ ret += wt * _tools .subspace_jtracedist (top , mop , mxBasis , n_leak )
673+
674+ else :
675+ raise ValueError ("Invalid gates_metric: %s" % gates_metric )
676+
677+ if "frobenius" in spam_metric and gates_metric == spam_metric :
678+ # We already handled SPAM error in this case. Just return.
679+ return ret
680+
681+ if "frobenius" in spam_metric :
682+ # SPAM and gates can have different choices for squared vs non-squared.
683+ wts = item_weights .copy (); wts ['gates' ] = 0.0
684+ for k in wts :
685+ if k in mdl .operations or k in mdl .instruments :
686+ wts [k ] = 0.0
687+ val = mdl .frobeniusdist (target_model , transform_mx_arg , wts )
688+ if "squared" in spam_metric :
689+ val = val ** 2
690+ ret += val
691+
692+ elif spam_metric == "fidelity" :
693+ # Leakage-aware metrics NOT available
694+ for preplabel , m_prep in mdl .preps .items ():
695+ wt = item_weights .get (preplabel , spamWeight )
696+ rhoMx1 = _tools .vec_to_stdmx (m_prep .to_dense (), mxBasis )
697+ t_prep = target_model .preps [preplabel ]
698+ rhoMx2 = _tools .vec_to_stdmx (t_prep .to_dense (), mxBasis )
699+ ret += wt * (1.0 - _tools .fidelity (rhoMx1 , rhoMx2 ))** 2
700+
701+ for povmlabel in mdl .povms .keys ():
702+ wt = item_weights .get (povmlabel , spamWeight )
703+ fidelity = _tools .povm_fidelity (mdl , target_model , povmlabel )
704+ ret += wt * (1.0 - fidelity )** 2
705+
706+ elif spam_metric == "tracedist" :
707+ # Leakage-aware metrics NOT available.
708+ for preplabel , m_prep in mdl .preps .items ():
709+ wt = item_weights .get (preplabel , spamWeight )
710+ rhoMx1 = _tools .vec_to_stdmx (m_prep .to_dense (), mxBasis )
711+ t_prep = target_model .preps [preplabel ]
712+ rhoMx2 = _tools .vec_to_stdmx (t_prep .to_dense (), mxBasis )
713+ ret += wt * _tools .tracedist (rhoMx1 , rhoMx2 )
714+
715+ for povmlabel in mdl .povms .keys ():
716+ wt = item_weights .get (povmlabel , spamWeight )
717+ ret += wt * _tools .povm_jtracedist (mdl , target_model , povmlabel )
718+
719+ else :
720+ raise ValueError ("Invalid spam_metric: %s" % spam_metric )
697721
698722 return ret
699723
0 commit comments