1818import functools
1919from typing import (
2020 Any ,
21- Callable ,
2221 Tuple ,
2322 Hashable ,
2423 List ,
25- Type ,
2624 overload ,
25+ Protocol ,
26+ Type ,
2727 TYPE_CHECKING ,
28+ TypeVar ,
2829)
2930import dataclasses
3031import enum
31- from cirq .circuits .circuit import CIRCUIT_TYPE
3232
3333if TYPE_CHECKING :
3434 import cirq
@@ -218,77 +218,41 @@ class TransformerContext:
218218 ignore_tags : Tuple [Hashable , ...] = ()
219219
220220
221- TRANSFORMER = Callable [['cirq.AbstractCircuit' , TransformerContext ], 'cirq.AbstractCircuit' ]
222- _TRANSFORMER_TYPE = Callable [['cirq.AbstractCircuit' , TransformerContext ], CIRCUIT_TYPE ]
223-
224-
225- def _transform_and_log (
226- func : _TRANSFORMER_TYPE [CIRCUIT_TYPE ],
227- transformer_name : str ,
228- circuit : 'cirq.AbstractCircuit' ,
229- context : TransformerContext ,
230- ) -> CIRCUIT_TYPE :
231- """Helper to log initial and final circuits before and after calling the transformer."""
232-
233- context .logger .register_initial (circuit , transformer_name )
234- transformed_circuit = func (circuit , context )
235- context .logger .register_final (transformed_circuit , transformer_name )
236- return transformed_circuit
237-
238-
239- def _transformer_class (
240- cls : Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]],
241- ) -> Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]]:
242- old_func = cls .__call__
243-
244- def transformer_with_logging_cls (
245- self : Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]],
246- circuit : 'cirq.AbstractCircuit' ,
247- context : TransformerContext ,
248- ) -> CIRCUIT_TYPE :
249- def call_old_func (c : 'cirq.AbstractCircuit' , ct : TransformerContext ) -> CIRCUIT_TYPE :
250- return old_func (self , c , ct )
251-
252- return _transform_and_log (call_old_func , cls .__name__ , circuit , context )
221+ class TRANSFORMER (Protocol ):
222+ def __call__ (
223+ self , circuit : 'cirq.AbstractCircuit' , context : TransformerContext
224+ ) -> 'cirq.AbstractCircuit' :
225+ ...
253226
254- setattr (cls , '__call__' , transformer_with_logging_cls )
255- return cls
256227
257-
258- def _transformer_func (func : _TRANSFORMER_TYPE [CIRCUIT_TYPE ]) -> _TRANSFORMER_TYPE [CIRCUIT_TYPE ]:
259- @functools .wraps (func )
260- def transformer_with_logging_func (
261- circuit : 'cirq.AbstractCircuit' ,
262- context : TransformerContext ,
263- ) -> CIRCUIT_TYPE :
264- return _transform_and_log (func , func .__name__ , circuit , context )
265-
266- return transformer_with_logging_func
228+ _TRANSFORMER_T = TypeVar ('_TRANSFORMER_T' , bound = TRANSFORMER )
229+ _TRANSFORMER_CLS_T = TypeVar ('_TRANSFORMER_CLS_T' , bound = Type [TRANSFORMER ])
267230
268231
269232@overload
270- def transformer (cls_or_func : _TRANSFORMER_TYPE [ CIRCUIT_TYPE ] ) -> _TRANSFORMER_TYPE [ CIRCUIT_TYPE ] :
233+ def transformer (cls_or_func : _TRANSFORMER_T ) -> _TRANSFORMER_T :
271234 pass
272235
273236
274237@overload
275- def transformer (
276- cls_or_func : Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]],
277- ) -> Type [_TRANSFORMER_TYPE [CIRCUIT_TYPE ]]:
238+ def transformer (cls_or_func : _TRANSFORMER_CLS_T ) -> _TRANSFORMER_CLS_T :
278239 pass
279240
280241
281242def transformer (cls_or_func : Any ) -> Any :
282243 """Decorator to verify API and append logging functionality to transformer functions & classes.
283244
284- The decorated function or class must satisfy
285- `Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example:
245+ A transformer is a callable that takes as inputs a cirq.AbstractCircuit and
246+ cirq.TransformerContext, and returns another cirq.AbstractCircuit without
247+ modifying the input circuit. A transformer could be a function, for example:
286248
287249 >>> @cirq.transformer
288- >>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit:
250+ >>> def convert_to_cz(
251+ >>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
252+ >>> ) -> cirq.Circuit:
289253 >>> ...
290254
291- The decorated class must implement the `__call__` method to satisfy the above API.
255+ Or it could be a class that implements `__call__` with the same API, for example:
292256
293257 >>> @cirq.transformer
294258 >>> class ConvertToSqrtISwaps:
@@ -300,14 +264,45 @@ def transformer(cls_or_func: Any) -> Any:
300264 >>> ...
301265
302266 Args:
303- cls_or_func: The callable class or method to be decorated.
267+ cls_or_func: The callable class or function to be decorated.
304268
305269 Returns:
306- Decorated class / method which includes additional logging boilerplate. The decorated
307- callable always receives a copy of the input circuit so that the input is never mutated.
270+ Decorated class / function which includes additional logging boilerplate.
308271 """
309272 if isinstance (cls_or_func , type ):
310- return _transformer_class (cls_or_func )
273+ cls = cls_or_func
274+ method = cls .__call__
275+
276+ @functools .wraps (method )
277+ def method_with_logging (self , circuit , context ):
278+ return _transform_and_log (
279+ lambda circuit , context : method (self , circuit , context ),
280+ cls .__name__ ,
281+ circuit ,
282+ context ,
283+ )
284+
285+ setattr (cls , '__call__' , method_with_logging )
286+ return cls
311287 else :
312288 assert callable (cls_or_func )
313- return _transformer_func (cls_or_func )
289+ func = cls_or_func
290+
291+ @functools .wraps (func )
292+ def func_with_logging (circuit , context ):
293+ return _transform_and_log (func , func .__name__ , circuit , context )
294+
295+ return func_with_logging
296+
297+
298+ def _transform_and_log (
299+ func : TRANSFORMER ,
300+ transformer_name : str ,
301+ circuit : 'cirq.AbstractCircuit' ,
302+ context : TransformerContext ,
303+ ) -> 'cirq.AbstractCircuit' :
304+ """Helper to log initial and final circuits before and after calling the transformer."""
305+ context .logger .register_initial (circuit , transformer_name )
306+ transformed_circuit = func (circuit , context )
307+ context .logger .register_final (transformed_circuit , transformer_name )
308+ return transformed_circuit
0 commit comments