diff --git a/src/bloqade/ir/location/bravais.py b/src/bloqade/ir/location/bravais.py index 130215d31..666b59d06 100644 --- a/src/bloqade/ir/location/bravais.py +++ b/src/bloqade/ir/location/bravais.py @@ -45,7 +45,7 @@ def coordinates(self, index: List[int]) -> NDArray: # damn! this is like stone age broadcasting vectors = np.array(self.cell_vectors()) index = np.array(index) - pos = np.sum(index * vectors.T, axis=1) + pos = np.sum(vectors.T * index, axis=1) return pos + np.array(self.cell_atoms()) def enumerate(self) -> Generator[LocationInfo, None, None]: @@ -103,7 +103,7 @@ def __init__( lattice_spacing_y: Optional[Any] = None, ): if lattice_spacing_y is None: - self.ratio = cast(1.0) + self.ratio = cast(1.0) / cast(lattice_spacing_x) else: self.ratio = cast(lattice_spacing_y) / cast(lattice_spacing_x) @@ -125,7 +125,7 @@ def cell_vectors(self) -> List[List[float]]: return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]] def cell_atoms(self) -> List[List[float]]: - return [[0.0, 0.0], [1 / 2, np.sqrt(3) / 2]] + return [[0.0, 0.0], [1 / 2, 1 / (2 * np.sqrt(3))]] @dataclass @@ -163,4 +163,4 @@ def cell_vectors(self) -> List[List[float]]: return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]] def cell_atoms(self) -> List[List[float]]: - return [[0.0, 0.0], [1 / 4, np.sqrt(3) / 4], [3 / 4, np.sqrt(3) / 2]] + return [[0.0, 0.0], [1 / 2, 0], [1 / 4, np.sqrt(3) / 4]] diff --git a/src/bloqade/ir/scalar.py b/src/bloqade/ir/scalar.py index 9332f5e8f..9837298f9 100644 --- a/src/bloqade/ir/scalar.py +++ b/src/bloqade/ir/scalar.py @@ -250,6 +250,8 @@ def trycast(py) -> Any: return tuple(map(cast, xs)) case Scalar(): return py + case tuple() as xs: + return tuple(map(cast, xs)) case _: return diff --git a/tests/test_bravais.py b/tests/test_bravais.py index c7efde89b..055de060e 100644 --- a/tests/test_bravais.py +++ b/tests/test_bravais.py @@ -1,11 +1,11 @@ -from bloqade.ir.location import Square, Rectangular, Honeycomb +from bloqade.ir.location import Lieb, Square, Rectangular, Honeycomb, Kagome, Triangular from bloqade import cast -import pytest +from math import sqrt -# @pytest.mark.skipif(True, reason="Not implemented") def test_square(): lattice = Square(3, lattice_spacing=2.0) + positions = set(map(lambda info: info.position, lattice.enumerate())) positions_expected = set( cast([(0, 0), (0, 2), (0, 4), (2, 0), (2, 2), (2, 4), (4, 0), (4, 2), (4, 4)]) @@ -14,7 +14,6 @@ def test_square(): assert positions == positions_expected -# @pytest.mark.skipif(True, reason="Not implemented") def test_rectangular(): lattice = Rectangular(2, 3, lattice_spacing_x=0.5, lattice_spacing_y=2) positions = set(map(lambda info: info.position, lattice.enumerate())) @@ -24,30 +23,98 @@ def test_rectangular(): assert positions == positions_expected -@pytest.mark.skipif(True, reason="Not implemented") +def test_rectagnular_default_yscale(): + lattice = Rectangular(2, 3, lattice_spacing_x=0.5) + positions = set(map(lambda info: info.position, lattice.enumerate())) + positions_expected = set( + cast([(0, 0), (0, 1.0), (0, 2.0), (0.5, 0), (0.5, 1.0), (0.5, 2.0)]) + ) + assert positions == positions_expected + + +def test_kagome(): + lattice = Kagome(2, lattice_spacing=2.0) + positions = set(map(lambda info: info.position, lattice.enumerate())) + positions_expected = set( + cast( + [ + (0, 0), + (1.0, 0), + (2, 0), + (3, 0), + (0.5, sqrt(3) * 0.5), + (2.5, sqrt(3) * 0.5), + (1, sqrt(3)), + (3, sqrt(3)), + (2, sqrt(3)), + (4, sqrt(3)), + (1.5, sqrt(3) * 1.5), + (3.5, sqrt(3) * 1.5), + ] + ) + ) + + assert positions == positions_expected + + +def test_triangular(): + lattice = Triangular(2, lattice_spacing=2.0) + positions = set(map(lambda info: info.position, lattice.enumerate())) + positions_expected = set(cast([(0, 0), (2, 0), (1.0, sqrt(3)), (3, sqrt(3))])) + + assert positions == positions_expected + + def test_honeycomb(): - lattice = Honeycomb(3, lattice_spacing=2) + lattice = Honeycomb(2, lattice_spacing=2) + positions = set(map(lambda info: info.position, lattice.enumerate())) + positions_expected = set( + cast( + [ + (0.0, 0), + (2.0, 0), + (1.0, 1 / sqrt(3)), + (3.0, 1 / sqrt(3)), + (1.0, sqrt(3)), + (3.0, sqrt(3)), + (2.0, sqrt(3) + 1 / sqrt(3)), + (4.0, sqrt(3) + 1 / sqrt(3)), + ] + ) + ) + + assert positions == positions_expected + + +def test_lieb(): + lattice = Lieb(2, lattice_spacing=2) positions = set(map(lambda info: info.position, lattice.enumerate())) positions_expected = set( cast( [ (0, 0), - (0, 2), - (0, 4), (2, 0), + (0, 2), (2, 2), - (2, 4), - (4, 0), - (4, 2), - (4, 4), - (6, 0), - (6, 2), - (6, 4), - (8, 0), - (8, 2), - (8, 4), + (1, 0), + (3, 0), + (1, 2), + (3, 2), + (0, 1), + (2, 1), + (0, 3), + (2, 3), ] ) ) assert positions == positions_expected + + +def test_scale_lattice(): + lattice = Triangular(2, lattice_spacing=1) + latt2 = lattice.scale(2) + positions = set(map(lambda info: info.position, latt2.enumerate())) + positions_expected = set(cast([(0, 0), (2, 0), (1.0, sqrt(3)), (3, sqrt(3))])) + + assert positions == positions_expected diff --git a/tests/test_builder.py b/tests/test_builder.py index 61921e85d..083d4da7f 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -5,12 +5,103 @@ # for i in range(5): # prog = prog.location(i) # prog.linear(start=1.0, stop=2.0, duration="x") +# import pytest import bloqade.ir as ir -from bloqade import start, var +from bloqade.builder import location +from bloqade.ir import rydberg, detuning +from bloqade import start, var, cast from bloqade.ir.location import Square, Chain import numpy as np +def test_piecewise_const(): + prog = start.rydberg.detuning.uniform.piecewise_constant( + durations=[0.1, 3.1, 0.05], values=[4, 4, 7.5] + ) + + ## inspect ir + node1 = prog + ir1 = node1._waveform + assert ir1.value == cast(7.5) + assert ir1.duration == cast(0.05) + + node2 = node1.__parent__ + ir2 = node2._waveform + assert ir2.value == cast(4) + assert ir2.duration == cast(3.1) + + node3 = node2.__parent__ + ir3 = node3._waveform + assert ir3.value == cast(4) + assert ir3.duration == cast(0.1) + + +def test_registers(): + waveform = ( + ir.Linear("initial_detuning", "initial_detuning", "up_time") + .append(ir.Linear("initial_detuning", "final_detuning", "anneal_time")) + .append(ir.Linear("final_detuning", "final_detuning", "up_time")) + ) + prog1 = start.rydberg.detuning.uniform.apply(waveform) + reg = prog1.register + + assert reg.n_atoms == 0 + assert reg.n_dims is None + + +def test_scale(): + prog = start + prog = ( + prog.rydberg.detuning.location(1) + .scale(1.2) + .piecewise_linear([0.1, 3.8, 0.1], [-10, -10, "a", "b"]) + ) + + ## let Emit build ast + seq = prog.sequence + + print(type(list(seq.value.keys())[0])) + Loc1 = list(seq.value[rydberg].value[detuning].value.keys())[0] + + assert type(Loc1) == ir.ScaledLocations + assert Loc1.value[ir.Location(1)] == cast(1.2) + + +def test_scale_location(): + prog = start.rydberg.detuning.location(1).scale(1.2).location(2).scale(3.3) + + assert prog._scale == cast(3.3) + assert type(prog.__parent__) == location.Location + assert prog.__parent__.__parent__._scale == cast(1.2) + + +def test_build_ast_Scale(): + prog = ( + start.rydberg.detuning.location(1) + .scale(1.2) + .location(2) + .scale(3.3) + .piecewise_constant(durations=[0.1], values=[1]) + ) + + # compile ast: + tmp = prog.sequence + + locs = list(tmp.value[rydberg].value[detuning].value.keys())[0] + wvfm = tmp.value[rydberg].value[detuning].value[locs] + + assert locs == ir.ScaledLocations( + {ir.Location(2): cast(3.3), ir.Location(1): cast(1.2)} + ) + assert wvfm == ir.Constant(value=cast(1), duration=cast(0.1)) + + +def test_spatial_var(): + prog = start.rydberg.detuning.var("a") + + assert prog._name == "a" + + def test_issue_107(): waveform = ( ir.Linear("initial_detuning", "initial_detuning", "up_time") diff --git a/tests/test_builder_factory.py b/tests/test_builder_factory.py new file mode 100644 index 000000000..a5d0894f3 --- /dev/null +++ b/tests/test_builder_factory.py @@ -0,0 +1,55 @@ +from bloqade.builder.factory import ( + piecewise_linear, + piecewise_constant, + constant, + linear, +) +from bloqade import cast + + +def test_ir_piecewise_linear(): + A = piecewise_linear([0.1, 3.8, 0.2], [-10, -7, "a", "b"]) + + ## Append type ir node + assert len(A.waveforms) == 3 + assert A.waveforms[0].duration == cast(0.1) + assert A.waveforms[0].start == cast(-10) + assert A.waveforms[0].stop == cast(-7) + + assert A.waveforms[1].duration == cast(3.8) + assert A.waveforms[1].start == cast(-7) + assert A.waveforms[1].stop == cast("a") + + assert A.waveforms[2].duration == cast(0.2) + assert A.waveforms[2].start == cast("a") + assert A.waveforms[2].stop == cast("b") + + +def test_ir_const(): + A = constant(value=3.415, duration=0.55) + + ## Constant type ir node: + assert A.value == cast(3.415) + assert A.duration == cast(0.55) + + +def test_ir_linear(): + A = linear(start=0.5, stop=3.2, duration=0.76) + + ## Linear type ir node: + assert A.start == cast(0.5) + assert A.stop == cast(3.2) + assert A.duration == cast(0.76) + + +def test_ir_piecewise_constant(): + A = piecewise_constant(durations=[0.1, 3.8, 0.2], values=[-10, "a", "b"]) + + assert A.waveforms[0].duration == cast(0.1) + assert A.waveforms[0].value == cast(-10) + + assert A.waveforms[1].duration == cast(3.8) + assert A.waveforms[1].value == cast("a") + + assert A.waveforms[2].duration == cast(0.2) + assert A.waveforms[2].value == cast("b")