Skip to content

Commit

Permalink
Fix bugs on bravais, add more tests, and bump up coverages (#285)
Browse files Browse the repository at this point in the history
* add more tests and enhance cast
1) add tests for scales
2) add tests for bravais
3) to do 2) this commit add cast for tuple type

* tmp

* add more testing cases, bumping to 71%

* fix baravias bugs and more test
1) fix bugs on wrong index*elementary vectors
2) bump testing cases on bravias to 96
3) fix bugs on Rectangular when spacing_y dose not specify

* update to comply ruff

* modify to comply ruff

* modify to comply black

* modify format

* try

* test pre-compile

* pre-commit linter fix

---------

Co-authored-by: Kai-Hsin Wu <khwu@KHWus-MBP.hsd1.ma.comcast.net>
Co-authored-by: Kai-Hsin Wu <khwu@KHWus-MacBook-Pro.local>
  • Loading branch information
3 people authored Jul 13, 2023
1 parent 4b21826 commit 6514f67
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 23 deletions.
8 changes: 4 additions & 4 deletions src/bloqade/ir/location/bravais.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def coordinates(self, index: List[int]) -> NDArray:
# damn! this is like stone age broadcasting
vectors = np.array(self.cell_vectors())
index = np.array(index)
pos = np.sum(index * vectors.T, axis=1)
pos = np.sum(vectors.T * index, axis=1)
return pos + np.array(self.cell_atoms())

def enumerate(self) -> Generator[LocationInfo, None, None]:
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(
lattice_spacing_y: Optional[Any] = None,
):
if lattice_spacing_y is None:
self.ratio = cast(1.0)
self.ratio = cast(1.0) / cast(lattice_spacing_x)
else:
self.ratio = cast(lattice_spacing_y) / cast(lattice_spacing_x)

Expand All @@ -125,7 +125,7 @@ def cell_vectors(self) -> List[List[float]]:
return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]]

def cell_atoms(self) -> List[List[float]]:
return [[0.0, 0.0], [1 / 2, np.sqrt(3) / 2]]
return [[0.0, 0.0], [1 / 2, 1 / (2 * np.sqrt(3))]]


@dataclass
Expand Down Expand Up @@ -163,4 +163,4 @@ def cell_vectors(self) -> List[List[float]]:
return [[1.0, 0.0], [1 / 2, np.sqrt(3) / 2]]

def cell_atoms(self) -> List[List[float]]:
return [[0.0, 0.0], [1 / 4, np.sqrt(3) / 4], [3 / 4, np.sqrt(3) / 2]]
return [[0.0, 0.0], [1 / 2, 0], [1 / 4, np.sqrt(3) / 4]]
2 changes: 2 additions & 0 deletions src/bloqade/ir/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def trycast(py) -> Any:
return tuple(map(cast, xs))
case Scalar():
return py
case tuple() as xs:
return tuple(map(cast, xs))
case _:
return

Expand Down
103 changes: 85 additions & 18 deletions tests/test_bravais.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from bloqade.ir.location import Square, Rectangular, Honeycomb
from bloqade.ir.location import Lieb, Square, Rectangular, Honeycomb, Kagome, Triangular
from bloqade import cast
import pytest
from math import sqrt


# @pytest.mark.skipif(True, reason="Not implemented")
def test_square():
lattice = Square(3, lattice_spacing=2.0)

positions = set(map(lambda info: info.position, lattice.enumerate()))
positions_expected = set(
cast([(0, 0), (0, 2), (0, 4), (2, 0), (2, 2), (2, 4), (4, 0), (4, 2), (4, 4)])
Expand All @@ -14,7 +14,6 @@ def test_square():
assert positions == positions_expected


# @pytest.mark.skipif(True, reason="Not implemented")
def test_rectangular():
lattice = Rectangular(2, 3, lattice_spacing_x=0.5, lattice_spacing_y=2)
positions = set(map(lambda info: info.position, lattice.enumerate()))
Expand All @@ -24,30 +23,98 @@ def test_rectangular():
assert positions == positions_expected


@pytest.mark.skipif(True, reason="Not implemented")
def test_rectagnular_default_yscale():
lattice = Rectangular(2, 3, lattice_spacing_x=0.5)
positions = set(map(lambda info: info.position, lattice.enumerate()))
positions_expected = set(
cast([(0, 0), (0, 1.0), (0, 2.0), (0.5, 0), (0.5, 1.0), (0.5, 2.0)])
)
assert positions == positions_expected


def test_kagome():
lattice = Kagome(2, lattice_spacing=2.0)
positions = set(map(lambda info: info.position, lattice.enumerate()))
positions_expected = set(
cast(
[
(0, 0),
(1.0, 0),
(2, 0),
(3, 0),
(0.5, sqrt(3) * 0.5),
(2.5, sqrt(3) * 0.5),
(1, sqrt(3)),
(3, sqrt(3)),
(2, sqrt(3)),
(4, sqrt(3)),
(1.5, sqrt(3) * 1.5),
(3.5, sqrt(3) * 1.5),
]
)
)

assert positions == positions_expected


def test_triangular():
lattice = Triangular(2, lattice_spacing=2.0)
positions = set(map(lambda info: info.position, lattice.enumerate()))
positions_expected = set(cast([(0, 0), (2, 0), (1.0, sqrt(3)), (3, sqrt(3))]))

assert positions == positions_expected


def test_honeycomb():
lattice = Honeycomb(3, lattice_spacing=2)
lattice = Honeycomb(2, lattice_spacing=2)
positions = set(map(lambda info: info.position, lattice.enumerate()))
positions_expected = set(
cast(
[
(0.0, 0),
(2.0, 0),
(1.0, 1 / sqrt(3)),
(3.0, 1 / sqrt(3)),
(1.0, sqrt(3)),
(3.0, sqrt(3)),
(2.0, sqrt(3) + 1 / sqrt(3)),
(4.0, sqrt(3) + 1 / sqrt(3)),
]
)
)

assert positions == positions_expected


def test_lieb():
lattice = Lieb(2, lattice_spacing=2)
positions = set(map(lambda info: info.position, lattice.enumerate()))
positions_expected = set(
cast(
[
(0, 0),
(0, 2),
(0, 4),
(2, 0),
(0, 2),
(2, 2),
(2, 4),
(4, 0),
(4, 2),
(4, 4),
(6, 0),
(6, 2),
(6, 4),
(8, 0),
(8, 2),
(8, 4),
(1, 0),
(3, 0),
(1, 2),
(3, 2),
(0, 1),
(2, 1),
(0, 3),
(2, 3),
]
)
)

assert positions == positions_expected


def test_scale_lattice():
lattice = Triangular(2, lattice_spacing=1)
latt2 = lattice.scale(2)
positions = set(map(lambda info: info.position, latt2.enumerate()))
positions_expected = set(cast([(0, 0), (2, 0), (1.0, sqrt(3)), (3, sqrt(3))]))

assert positions == positions_expected
93 changes: 92 additions & 1 deletion tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,103 @@
# for i in range(5):
# prog = prog.location(i)
# prog.linear(start=1.0, stop=2.0, duration="x")
# import pytest
import bloqade.ir as ir
from bloqade import start, var
from bloqade.builder import location
from bloqade.ir import rydberg, detuning
from bloqade import start, var, cast
from bloqade.ir.location import Square, Chain
import numpy as np


def test_piecewise_const():
prog = start.rydberg.detuning.uniform.piecewise_constant(
durations=[0.1, 3.1, 0.05], values=[4, 4, 7.5]
)

## inspect ir
node1 = prog
ir1 = node1._waveform
assert ir1.value == cast(7.5)
assert ir1.duration == cast(0.05)

node2 = node1.__parent__
ir2 = node2._waveform
assert ir2.value == cast(4)
assert ir2.duration == cast(3.1)

node3 = node2.__parent__
ir3 = node3._waveform
assert ir3.value == cast(4)
assert ir3.duration == cast(0.1)


def test_registers():
waveform = (
ir.Linear("initial_detuning", "initial_detuning", "up_time")
.append(ir.Linear("initial_detuning", "final_detuning", "anneal_time"))
.append(ir.Linear("final_detuning", "final_detuning", "up_time"))
)
prog1 = start.rydberg.detuning.uniform.apply(waveform)
reg = prog1.register

assert reg.n_atoms == 0
assert reg.n_dims is None


def test_scale():
prog = start
prog = (
prog.rydberg.detuning.location(1)
.scale(1.2)
.piecewise_linear([0.1, 3.8, 0.1], [-10, -10, "a", "b"])
)

## let Emit build ast
seq = prog.sequence

print(type(list(seq.value.keys())[0]))
Loc1 = list(seq.value[rydberg].value[detuning].value.keys())[0]

assert type(Loc1) == ir.ScaledLocations
assert Loc1.value[ir.Location(1)] == cast(1.2)


def test_scale_location():
prog = start.rydberg.detuning.location(1).scale(1.2).location(2).scale(3.3)

assert prog._scale == cast(3.3)
assert type(prog.__parent__) == location.Location
assert prog.__parent__.__parent__._scale == cast(1.2)


def test_build_ast_Scale():
prog = (
start.rydberg.detuning.location(1)
.scale(1.2)
.location(2)
.scale(3.3)
.piecewise_constant(durations=[0.1], values=[1])
)

# compile ast:
tmp = prog.sequence

locs = list(tmp.value[rydberg].value[detuning].value.keys())[0]
wvfm = tmp.value[rydberg].value[detuning].value[locs]

assert locs == ir.ScaledLocations(
{ir.Location(2): cast(3.3), ir.Location(1): cast(1.2)}
)
assert wvfm == ir.Constant(value=cast(1), duration=cast(0.1))


def test_spatial_var():
prog = start.rydberg.detuning.var("a")

assert prog._name == "a"


def test_issue_107():
waveform = (
ir.Linear("initial_detuning", "initial_detuning", "up_time")
Expand Down
55 changes: 55 additions & 0 deletions tests/test_builder_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from bloqade.builder.factory import (
piecewise_linear,
piecewise_constant,
constant,
linear,
)
from bloqade import cast


def test_ir_piecewise_linear():
A = piecewise_linear([0.1, 3.8, 0.2], [-10, -7, "a", "b"])

## Append type ir node
assert len(A.waveforms) == 3
assert A.waveforms[0].duration == cast(0.1)
assert A.waveforms[0].start == cast(-10)
assert A.waveforms[0].stop == cast(-7)

assert A.waveforms[1].duration == cast(3.8)
assert A.waveforms[1].start == cast(-7)
assert A.waveforms[1].stop == cast("a")

assert A.waveforms[2].duration == cast(0.2)
assert A.waveforms[2].start == cast("a")
assert A.waveforms[2].stop == cast("b")


def test_ir_const():
A = constant(value=3.415, duration=0.55)

## Constant type ir node:
assert A.value == cast(3.415)
assert A.duration == cast(0.55)


def test_ir_linear():
A = linear(start=0.5, stop=3.2, duration=0.76)

## Linear type ir node:
assert A.start == cast(0.5)
assert A.stop == cast(3.2)
assert A.duration == cast(0.76)


def test_ir_piecewise_constant():
A = piecewise_constant(durations=[0.1, 3.8, 0.2], values=[-10, "a", "b"])

assert A.waveforms[0].duration == cast(0.1)
assert A.waveforms[0].value == cast(-10)

assert A.waveforms[1].duration == cast(3.8)
assert A.waveforms[1].value == cast("a")

assert A.waveforms[2].duration == cast(0.2)
assert A.waveforms[2].value == cast("b")

0 comments on commit 6514f67

Please sign in to comment.