Skip to content

Commit 85f43c8

Browse files
committed
improve performance
direct write data to GPU memory if already allocated
1 parent 7da6b9f commit 85f43c8

File tree

3 files changed

+14
-11
lines changed

3 files changed

+14
-11
lines changed

qmlant/neural_networks/estimator_tn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ def prepare_circuit(
3131
expr: str,
3232
operands: list[cp.ndarray],
3333
) -> tuple[str, list[cp.ndarray]]:
34-
"""prepare a circuit for forward process setting batch and parameters to the circuit.
35-
"""
34+
"""prepare a circuit for forward process setting batch and parameters to the circuit."""
3635

3736
pname2theta_list = {
3837
f"x[{i}]": batch[:, i].flatten().tolist() for i in range(batch.shape[1])

qmlant/neural_networks/neural_network.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66

7-
def Ry(theta: float, xp=cp):
7+
def Ry(theta: float, xp=cp) -> cp.ndarray:
88
cos = xp.cos(theta / 2)
99
sin = xp.sin(theta / 2)
1010
return xp.array(
@@ -13,11 +13,19 @@ def Ry(theta: float, xp=cp):
1313
)
1414

1515

16-
def Ry_Rydag(theta: float, xp=cp):
16+
def Ry_Rydag(theta: float, xp=cp) -> tuple[np.ndarray | cp.ndarray, np.ndarray | cp.ndarray]:
1717
cos = xp.cos(theta / 2)
1818
sin = xp.sin(theta / 2)
1919
ry_rydag = xp.array(
2020
[[[cos, -sin], [sin, cos]], [[cos, sin], [-sin, cos]]],
2121
dtype=complex,
2222
)
2323
return ry_rydag[0], ry_rydag[1]
24+
25+
26+
def Ry_Rydag_direct(theta: float, mat: cp.ndarray, mat_dag: cp.ndarray) -> None:
27+
cos = cp.cos(theta / 2)
28+
sin = cp.sin(theta / 2)
29+
mat[0][0] = mat[1][1] = mat_dag[0][0] = mat_dag[1][1] = cos
30+
mat[0][1] = mat_dag[1][0] = -sin
31+
mat[1][0] = mat_dag[0][1] = sin

qmlant/neural_networks/utils/utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from cuquantum import CircuitToEinsum
1010
from qiskit import QuantumCircuit
1111

12-
from ..neural_network import Ry_Rydag
12+
from ..neural_network import Ry_Rydag, Ry_Rydag_direct
1313

1414

1515
@overload
@@ -101,10 +101,8 @@ def replace_ry(
101101
pname2locs: dict[str, tuple[int, int]],
102102
) -> list[cp.ndarray]:
103103
for pname, theta in pname2theta.items(): # e.g. pname[0] = "θ[0]"
104-
ry, ry_dag = Ry_Rydag(theta)
105104
loc, dag_loc = pname2locs[pname]
106-
operands[loc] = ry
107-
operands[dag_loc] = ry_dag
105+
Ry_Rydag_direct(theta, operands[loc], operands[dag_loc])
108106

109107
return operands
110108

@@ -119,10 +117,8 @@ def replace_ry_phase_shift(
119117
# θ[0]: [π/2, -π/2], θ[1]: [π/2, -π/2], ...
120118
for pname, theta in pname2theta.items(): # e.g. pname[0] = "θ[0]"
121119
for phase_shift in phase_shift_list:
122-
ry, ry_dag = Ry_Rydag(theta + phase_shift)
123120
loc, dag_loc = pname2locs[pname]
124-
operands[loc][i, :] = ry
125-
operands[dag_loc][i, :] = ry_dag
121+
Ry_Rydag_direct(theta + phase_shift, operands[loc][i], operands[dag_loc][i])
126122
i += 1
127123

128124
return operands

0 commit comments

Comments
 (0)