Skip to content

Commit 989abad

Browse files
authored
An MPS (approximate) simulator (#3630)
This is based on this algo: https://arxiv.org/abs/2002.07730
1 parent 2a9fff7 commit 989abad

File tree

5 files changed

+577
-0
lines changed

5 files changed

+577
-0
lines changed

cirq/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,10 @@
363363
final_density_matrix,
364364
final_state_vector,
365365
final_wavefunction,
366+
MPSSimulator,
367+
MPSSimulatorStepResult,
368+
MPSState,
369+
MPSTrialResult,
366370
sample,
367371
sample_density_matrix,
368372
sample_state_vector,

cirq/protocols/json_serialization_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ def test_mutually_exclusive_blacklist():
298298
'LinearCombinationOfOperations',
299299
'Linspace',
300300
'ListSweep',
301+
'MPSSimulator',
302+
'MPSSimulatorStepResult',
303+
'MPSState',
304+
'MPSTrialResult',
301305
'NeutralAtomDevice',
302306
'PauliInteractionGate',
303307
'PauliStringPhasor',

cirq/sim/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@
3232
DensityMatrixTrialResult,
3333
)
3434

35+
from cirq.sim.mps_simulator import (
36+
MPSSimulator,
37+
MPSSimulatorStepResult,
38+
MPSState,
39+
MPSTrialResult,
40+
)
41+
3542
from cirq.sim.mux import (
3643
CIRCUIT_LIKE,
3744
final_density_matrix,

cirq/sim/mps_simulator.py

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
# Copyright 2019 The Cirq Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""An MPS simulator.
15+
16+
This is based on this paper:
17+
https://arxiv.org/abs/2002.07730
18+
19+
TODO(tonybruguier): Currently, only linear circuits are handled, while the paper
20+
handles more general topologies.
21+
22+
TODO(tonybruguier): Currently, numpy is used for tensor computations. For speed
23+
switch to QIM for speed.
24+
"""
25+
26+
import collections
27+
import math
28+
from typing import Any, Dict, List, Iterator, Sequence
29+
30+
import numpy as np
31+
32+
import cirq
33+
from cirq import circuits, study, ops, protocols, value
34+
from cirq.sim import simulator
35+
36+
37+
class MPSSimulator(simulator.SimulatesSamples, simulator.SimulatesIntermediateState):
38+
"""An efficient simulator for MPS circuits."""
39+
40+
def __init__(self, seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None):
41+
"""Creates instance of `MPSSimulator`.
42+
43+
Args:
44+
seed: The random seed to use for this simulator.
45+
"""
46+
self.init = True
47+
self._prng = value.parse_random_state(seed)
48+
49+
def _base_iterator(
50+
self, circuit: circuits.Circuit, qubit_order: ops.QubitOrderOrList, initial_state: int
51+
) -> Iterator['cirq.MPSSimulatorStepResult']:
52+
"""Iterator over MPSSimulatorStepResult from Moments of a Circuit
53+
54+
Args:
55+
circuit: The circuit to simulate.
56+
qubit_order: Determines the canonical ordering of the qubits. This
57+
is often used in specifying the initial state, i.e. the
58+
ordering of the computational basis states.
59+
initial_state: The initial state for the simulation in the
60+
computational basis. Represented as a big endian int.
61+
62+
63+
Yields:
64+
MPSStepResult from simulating a Moment of the Circuit.
65+
"""
66+
qubits = ops.QubitOrder.as_qubit_order(qubit_order).order_for(circuit.all_qubits())
67+
68+
qubit_map = {q: i for i, q in enumerate(qubits)}
69+
70+
if len(circuit) == 0:
71+
yield MPSSimulatorStepResult(
72+
measurements={}, state=MPSState(qubit_map, initial_state=initial_state)
73+
)
74+
return
75+
76+
state = MPSState(qubit_map, initial_state=initial_state)
77+
78+
for moment in circuit:
79+
measurements: Dict[str, List[int]] = collections.defaultdict(list)
80+
81+
for op in moment:
82+
if isinstance(op.gate, ops.MeasurementGate):
83+
key = str(protocols.measurement_key(op))
84+
measurements[key].extend(state.perform_measurement(op.qubits, self._prng))
85+
elif protocols.has_unitary(op):
86+
state.apply_unitary(op)
87+
else:
88+
raise NotImplementedError(f"Unrecognized operation: {op!r}")
89+
90+
yield MPSSimulatorStepResult(measurements=measurements, state=state)
91+
92+
def _simulator_iterator(
93+
self,
94+
circuit: circuits.Circuit,
95+
param_resolver: study.ParamResolver,
96+
qubit_order: ops.QubitOrderOrList,
97+
initial_state: int,
98+
) -> Iterator:
99+
"""See definition in `cirq.SimulatesIntermediateState`.
100+
101+
Args:
102+
inital_state: An integer specifying the inital
103+
state in the computational basis.
104+
"""
105+
param_resolver = param_resolver or study.ParamResolver({})
106+
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
107+
self._check_all_resolved(resolved_circuit)
108+
actual_initial_state = 0 if initial_state is None else initial_state
109+
110+
return self._base_iterator(resolved_circuit, qubit_order, actual_initial_state)
111+
112+
def _create_simulator_trial_result(
113+
self,
114+
params: study.ParamResolver,
115+
measurements: Dict[str, np.ndarray],
116+
final_simulator_state,
117+
):
118+
119+
return MPSTrialResult(
120+
params=params, measurements=measurements, final_simulator_state=final_simulator_state
121+
)
122+
123+
def _run(
124+
self, circuit: circuits.Circuit, param_resolver: study.ParamResolver, repetitions: int
125+
) -> Dict[str, List[np.ndarray]]:
126+
127+
param_resolver = param_resolver or study.ParamResolver({})
128+
resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
129+
self._check_all_resolved(resolved_circuit)
130+
131+
measurements = {} # type: Dict[str, List[np.ndarray]]
132+
if repetitions == 0:
133+
for _, op, _ in resolved_circuit.findall_operations_with_gate_type(ops.MeasurementGate):
134+
measurements[protocols.measurement_key(op)] = np.empty([0, 1])
135+
136+
for _ in range(repetitions):
137+
all_step_results = self._base_iterator(
138+
resolved_circuit, qubit_order=ops.QubitOrder.DEFAULT, initial_state=0
139+
)
140+
141+
for step_result in all_step_results:
142+
for k, v in step_result.measurements.items():
143+
if not k in measurements:
144+
measurements[k] = []
145+
measurements[k].append(np.array(v, dtype=int))
146+
147+
return {k: np.array(v) for k, v in measurements.items()}
148+
149+
def _check_all_resolved(self, circuit):
150+
"""Raises if the circuit contains unresolved symbols."""
151+
if protocols.is_parameterized(circuit):
152+
unresolved = [
153+
op for moment in circuit for op in moment if protocols.is_parameterized(op)
154+
]
155+
raise ValueError(
156+
'Circuit contains ops whose symbols were not specified in '
157+
'parameter sweep. Ops: {}'.format(unresolved)
158+
)
159+
160+
161+
class MPSTrialResult(simulator.SimulationTrialResult):
162+
def __init__(
163+
self,
164+
params: study.ParamResolver,
165+
measurements: Dict[str, np.ndarray],
166+
final_simulator_state: 'MPSState',
167+
) -> None:
168+
super().__init__(
169+
params=params, measurements=measurements, final_simulator_state=final_simulator_state
170+
)
171+
172+
self.final_state = final_simulator_state
173+
174+
def __str__(self) -> str:
175+
samples = super().__str__()
176+
final = self._final_simulator_state
177+
return f'measurements: {samples}\noutput state: {final}'
178+
179+
180+
class MPSSimulatorStepResult(simulator.StepResult):
181+
"""A `StepResult` that includes `StateVectorMixin` methods."""
182+
183+
def __init__(self, state, measurements):
184+
"""Results of a step of the simulator.
185+
Attributes:
186+
state: A MPSState
187+
measurements: A dictionary from measurement gate key to measurement
188+
results, ordered by the qubits that the measurement operates on.
189+
qubit_map: A map from the Qubits in the Circuit to the the index
190+
of this qubit for a canonical ordering. This canonical ordering
191+
is used to define the state vector (see the state_vector()
192+
method).
193+
"""
194+
self.measurements = measurements
195+
self.state = state.copy()
196+
197+
def __str__(self) -> str:
198+
def bitstring(vals):
199+
return ','.join(str(v) for v in vals)
200+
201+
results = sorted([(key, bitstring(val)) for key, val in self.measurements.items()])
202+
203+
if len(results) == 0:
204+
measurements = ''
205+
else:
206+
measurements = ' '.join([f'{key}={val}' for key, val in results]) + '\n'
207+
208+
final = self.state
209+
210+
return f'{measurements}{final}'
211+
212+
def _simulator_state(self):
213+
return self.state
214+
215+
def sample(
216+
self,
217+
qubits: List[ops.Qid],
218+
repetitions: int = 1,
219+
seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
220+
) -> np.ndarray:
221+
222+
measurements: List[int] = []
223+
224+
for _ in range(repetitions):
225+
measurements.append(
226+
self.state.perform_measurement(
227+
qubits, value.parse_random_state(seed), collapse_state_vector=False
228+
)
229+
)
230+
231+
return np.array(measurements, dtype=int)
232+
233+
234+
@value.value_equality
235+
class MPSState:
236+
"""A state of the MPS simulation."""
237+
238+
def __init__(self, qubit_map, initial_state=0):
239+
self.qubit_map = qubit_map
240+
self.M = []
241+
for qubit in qubit_map.keys():
242+
d = qubit.dimension
243+
x = np.zeros(
244+
(
245+
1,
246+
1,
247+
d,
248+
)
249+
)
250+
x[0, 0, (initial_state % d)] = 1.0
251+
self.M.append(x)
252+
initial_state = initial_state // d
253+
self.M = self.M[::-1]
254+
self.threshold = 1e-3
255+
256+
def __str__(self) -> str:
257+
return str(self.M)
258+
259+
def _value_equality_values_(self) -> Any:
260+
return self.qubit_map, self.M, self.threshold
261+
262+
def copy(self) -> 'MPSState':
263+
state = MPSState(self.qubit_map)
264+
state.M = [x.copy() for x in self.M]
265+
state.threshold = self.threshold
266+
return state
267+
268+
def state_vector(self):
269+
M = np.ones((1, 1))
270+
for i in range(len(self.M)):
271+
M = np.einsum('ni,npj->pij', M, self.M[i])
272+
M = M.reshape(M.shape[0], -1)
273+
assert M.shape[0] == 1
274+
return M[0, :]
275+
276+
def to_numpy(self) -> np.ndarray:
277+
return self.state_vector()
278+
279+
def apply_unitary(self, op: 'cirq.Operation'):
280+
idx = [self.qubit_map[qubit] for qubit in op.qubits]
281+
U = protocols.unitary(op).reshape([qubit.dimension for qubit in op.qubits] * 2)
282+
283+
if len(idx) == 1:
284+
n = idx[0]
285+
self.M[n] = np.einsum('ij,mnj->mni', U, self.M[n])
286+
elif len(idx) == 2:
287+
n = idx[0]
288+
p = idx[1]
289+
if abs(n - p) != 1:
290+
raise ValueError('Can only handle continguous qubits')
291+
T = np.einsum('klij,mni,npj->mkpl', U, self.M[n], self.M[p])
292+
X, S, Y = np.linalg.svd(T.reshape([T.shape[0] * T.shape[1], T.shape[2] * T.shape[3]]))
293+
X = X.reshape([T.shape[0], T.shape[1], -1])
294+
Y = Y.reshape([-1, T.shape[2], T.shape[3]])
295+
296+
S = np.asarray([math.sqrt(x) for x in S])
297+
298+
nkeep = 0
299+
for i in range(S.shape[0]):
300+
if S[i] >= S[0] * self.threshold:
301+
nkeep = i + 1
302+
303+
X = X[:, :, :nkeep]
304+
S = np.diag(S[:nkeep])
305+
Y = Y[:nkeep, :, :]
306+
307+
self.M[n] = np.einsum('mis,sn->mni', X, S)
308+
self.M[p] = np.einsum('ns,spj->npj', S, Y)
309+
else:
310+
raise ValueError('Can only handle 1 and 2 qubit operations')
311+
312+
def perform_measurement(
313+
self, qubits: Sequence[ops.Qid], prng: np.random.RandomState, collapse_state_vector=True
314+
) -> List[int]:
315+
results: List[int] = []
316+
317+
if collapse_state_vector:
318+
state = self
319+
else:
320+
state = self.copy()
321+
322+
for qubit in qubits:
323+
n = state.qubit_map[qubit]
324+
325+
M = np.ones((1, 1))
326+
for i in range(len(state.M)):
327+
if i == n:
328+
M = np.einsum('ni,npj->pij', M, state.M[i])
329+
else:
330+
M = np.einsum('ni,npj->pi', M, state.M[i])
331+
M = M.reshape(M.shape[0], -1)
332+
assert M.shape[0] == 1
333+
M = M.reshape(-1)
334+
probs = [abs(x) ** 2 for x in M]
335+
336+
# Because the computation is approximate, the probabilities do not
337+
# necessarily add up to 1.0, and thus we re-normalize them.
338+
norm_probs = [x / sum(probs) for x in probs]
339+
340+
d = qubit.dimension
341+
result: int = int(prng.choice(d, p=norm_probs))
342+
343+
renormalizer = np.zeros((d, d))
344+
renormalizer[result][result] = 1.0 / math.sqrt(probs[result])
345+
346+
state.M[n] = np.einsum('ij,mnj->mni', renormalizer, state.M[n])
347+
348+
results.append(result)
349+
350+
return results

0 commit comments

Comments
 (0)