Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run product state functions inplace to avoid copies where possible #6396

Merged
merged 4 commits into from
Feb 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Run product state merges inplace to avoid copies
  • Loading branch information
daxfohl committed Jan 6, 2024
commit c2344b2ed08c4dbcb8e196ffb4a82e9d6b589147
18 changes: 11 additions & 7 deletions cirq-core/cirq/sim/simulation_product_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,16 @@ def split_untangled_states(self) -> bool:
return self._split_untangled_states

def create_merged_state(self) -> TSimulationState:
final_state = self.sim_states[None]
if not self.split_untangled_states:
return self.sim_states[None]
final_args = self.sim_states[None]
for args in set([self.sim_states[k] for k in self.sim_states.keys() if k is not None]):
final_args = final_args.kronecker_product(args)
return final_args.transpose_to_qubit_order(self.qubits)
return final_state
extra_states = set([self.sim_states[k] for k in self.sim_states.keys() if k is not None])
if not extra_states:
return final_state
final_state = final_state.copy(deep_copy_buffers=False)
for state in extra_states:
final_state.kronecker_product(state, inplace=True)
return final_state.transpose_to_qubit_order(self.qubits, inplace=True)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
Expand Down Expand Up @@ -106,7 +110,7 @@ def _act_on_fallback_(
if op_args_opt is None:
op_args_opt = self.sim_states[q]
elif q not in op_args_opt.qubits:
op_args_opt = op_args_opt.kronecker_product(self.sim_states[q])
op_args_opt.kronecker_product(self.sim_states[q], inplace=True)
op_args = op_args_opt or self.sim_states[None]

# (Backfill the args map with the new value)
Expand All @@ -123,7 +127,7 @@ def _act_on_fallback_(
):
for q in qubits:
if op_args.allows_factoring and len(op_args.qubits) > 1:
q_args, op_args = op_args.factor((q,), validate=False)
q_args, _ = op_args.factor((q,), validate=False, inplace=True)
self._sim_states[q] = q_args

# (Backfill the args map with the new value)
Expand Down