Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug related to codegen for sample python functions #307

Merged
merged 32 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
be47097
add more tests
Jul 14, 2023
a87ab8a
Merge branch 'main' into khwu/enhance_testv4
Jul 14, 2023
a0c9d64
add more integration tests for codegen
Jul 14, 2023
0ca26b3
fix bug when ConstantCodeGen hit Negative wvfm node.
Jul 17, 2023
0e906a1
Merge branch 'main' into khwu/enhance_testv4
Jul 17, 2023
69d9ce1
tmp
Jul 17, 2023
6d0c109
fix bug when slice produce duplicate time points
Jul 17, 2023
a8fccb7
Merge branch 'khwu/enhance_testv4' into khwu/enhance_testv5
Jul 17, 2023
fea7e20
tmp
Jul 17, 2023
4a1627e
Merge branch 'main' into khwu/enhance_testv5
Jul 17, 2023
698e834
tmp
Jul 17, 2023
c92ef69
Merge branch 'main' into khwu/enhance_testv5
Jul 17, 2023
af7c0da
adding fix for Sample interpolation.
weinbe58 Jul 17, 2023
af59e4c
tmp
Jul 17, 2023
9bfe755
Merge branch 'phil/fix-sample-waveform-ast' into khwu/enhance_testv5
Jul 17, 2023
61bf678
add framework for pretty print testing and fix bugs #297
Jul 17, 2023
2c725e7
fix bug in print with children
Jul 17, 2023
dfc4fc8
remove [html] from .coveragerc and fix bugs in print testing
Jul 17, 2023
48c181e
finished assignment scan tests
Jul 17, 2023
bf58a4d
fix bug in codegen slice
Jul 18, 2023
72e0fdb
fix conflict
Jul 18, 2023
0348945
making `Append` give correct value for `eval_decimal(duration)`.
weinbe58 Jul 18, 2023
5e8d290
add more tests for batch_assign
Jul 18, 2023
c0c15ba
Merge branch '302-record-does-not-properly-record-the-current-value' …
Jul 18, 2023
cc36d42
add more testing cases
Jul 18, 2023
4fa82cd
Merge branch 'main' into khwu/enhance_testv6
Jul 18, 2023
0263bd1
add more tests and fix bugs related to fn() sampling
Jul 18, 2023
e967816
Merge branch 'main' into khwu/enhance_testv8
Jul 18, 2023
68db7b2
Merge branch 'main' into khwu/enhance_testv8
weinbe58 Jul 19, 2023
b01c23d
fixing `samples` of `Sample` + adding example of using codegen directly.
weinbe58 Jul 19, 2023
80013b8
fix bug in matching of namedpulese
Jul 20, 2023
f37bd8e
Merge branch 'main' into khwu/enhance_testv8
Roger-luo Jul 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add more testing cases
  • Loading branch information
Kai-Hsin Wu authored and Kai-Hsin Wu committed Jul 18, 2023
commit cc36d429d30dc4cf08be7a10bd64de45234d2209
47 changes: 47 additions & 0 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from bloqade.ir import Interval
import pytest
from bloqade import cast
from io import StringIO
from IPython.lib.pretty import PrettyPrinter as PP


def test_lvlcouple_base():
Expand All @@ -27,13 +29,25 @@ def test_lvlcouple_hf():
assert lc.print_node() == "HyperfineLevelCoupling"
assert lc.__repr__() == "hyperfine"

mystdout = StringIO()
p = PP(mystdout)
lc._repr_pretty_(p, 2)

assert mystdout.getvalue() == "HyperfineLevelCoupling\n"


def test_lvlcouple_ryd():
lc = rydberg

assert lc.print_node() == "RydbergLevelCoupling"
assert lc.__repr__() == "rydberg"

mystdout = StringIO()
p = PP(mystdout)
lc._repr_pretty_(p, 2)

assert mystdout.getvalue() == "RydbergLevelCoupling\n"


def test_seqence():
# seq empty
Expand Down Expand Up @@ -75,6 +89,22 @@ def test_slice_sequence():
assert slc.children() == {"sequence": seq_full, "interval": itvl}
assert slc.print_node() == "Slice"

mystdout = StringIO()
p = PP(mystdout)
slc._repr_pretty_(p, 2)

assert (
mystdout.getvalue()
== "Slice\n"
+ "├─ sequence ⇒ Sequence\n"
+ "│ └─ RydbergLevelCoupling ⇒ Pulse\n"
+ "│ └─ Detuning ⇒ Field\n"
+ "⋮\n"
+ "└─ interval ⇒ Interval\n"
+ " ├─ start ⇒ Literal: 0\n"
+ " └─ stop ⇒ Literal: 1.5"
)


def test_append_sequence():
f = Field({Uniform: Linear(start=1.0, stop=2.0, duration=3.0)})
Expand All @@ -86,6 +116,23 @@ def test_append_sequence():
assert app.children() == [seq_full, seq_full]
assert app.print_node() == "Append"

mystdout = StringIO()
p = PP(mystdout)
app._repr_pretty_(p, 2)

assert (
mystdout.getvalue()
== "Append\n"
+ "├─ Sequence\n"
+ "│ └─ RydbergLevelCoupling ⇒ Pulse\n"
+ "│ └─ Detuning ⇒ Field\n"
+ "⋮\n"
+ "└─ Sequence\n"
+ " └─ RydbergLevelCoupling ⇒ Pulse\n"
+ " └─ Detuning ⇒ Field\n"
+ "⋮\n"
)


seq = Sequence(
{
Expand Down
15 changes: 15 additions & 0 deletions tests/test_submission_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from bloqade.submission.base import SubmissionBackend
import pytest


def test_submission_base():
A = SubmissionBackend()

with pytest.raises(NotImplementedError):
A.cancel_task("1")

with pytest.raises(NotImplementedError):
A.task_results("1")

with pytest.raises(NotImplementedError):
A.task_status("1")
130 changes: 130 additions & 0 deletions tests/test_waveform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from decimal import Decimal
import pytest
import numpy as np
from io import StringIO
from IPython.lib.pretty import PrettyPrinter as PP


def test_wvfm_base():
Expand Down Expand Up @@ -54,6 +56,16 @@ def test_wvfm_constant():
assert wf.eval_decimal(clock_s=Decimal("6.0")) == 0
assert wf.children() == {"value": cast(1.0), "duration": cast(3.0)}

mystdout = StringIO()
p = PP(mystdout)

wf._repr_pretty_(p, 0)

assert (
mystdout.getvalue()
== "Constant\n├─ value ⇒ Literal: 1.0\n⋮\n└─ duration ⇒ Literal: 3.0⋮\n"
)


def test_wvfm_pyfn():
def my_func(time, *, omega, phi=0, amplitude):
Expand Down Expand Up @@ -98,6 +110,13 @@ def my_func3(time, omega, **phi):
assert wf.children() == {"duration": cast(1.0)}
assert wf.duration == cast(1.0)

mystdout = StringIO()
p = PP(mystdout)

wf._repr_pretty_(p, 0)

assert mystdout.getvalue() == "PythonFn: my_func\n└─ duration ⇒ Literal: 1.0⋮\n"


def test_wvfm_app():
wf = Linear(start=1.0, stop=2.0, duration=3.0)
Expand All @@ -107,6 +126,14 @@ def test_wvfm_app():

assert wf3.print_node() == "Append"
assert wf3.children() == [wf, wf2]
assert wf3.eval_decimal(Decimal(10)) == Decimal(0)

mystdout = StringIO()
p = PP(mystdout)

wf3._repr_pretty_(p, 0)

assert mystdout.getvalue() == "Append\n├─ Linear\n⋮\n└─ Constant\n⋮\n"


def test_wvfm_neg():
Expand All @@ -119,6 +146,19 @@ def test_wvfm_neg():

assert wf2.eval_decimal(Decimal("0.5")) == Decimal("-1.0")

mystdout = StringIO()
p = PP(mystdout)

wf2._repr_pretty_(p, 2)

assert (
mystdout.getvalue()
== "-\n"
+ "└─ Constant\n"
+ " ├─ value ⇒ Literal: 1.0\n"
+ " └─ duration ⇒ Literal: 3.0"
)


def test_wvfm_scale():
wf = Constant(value=1.0, duration=3.0)
Expand All @@ -137,6 +177,20 @@ def test_wvfm_scale():

assert wf3.eval_decimal(Decimal("0.5")) == Decimal("2.0")

mystdout = StringIO()
p = PP(mystdout)

wf3._repr_pretty_(p, 2)

assert (
mystdout.getvalue()
== "Scale\n"
+ "├─ Literal: 2.0\n"
+ "└─ Constant\n"
+ " ├─ value ⇒ Literal: 1.0\n"
+ " └─ duration ⇒ Literal: 3.0"
)


def test_wvfn_add():
wf = Constant(value=1.0, duration=3.0)
Expand All @@ -151,6 +205,23 @@ def test_wvfn_add():
assert wf3.eval_decimal(Decimal("0")) == Decimal("2.0")
assert wf3.eval_decimal(Decimal("2.5")) == Decimal("1.0")

mystdout = StringIO()
p = PP(mystdout)

wf3._repr_pretty_(p, 2)

assert (
mystdout.getvalue()
== "+\n"
+ "├─ Constant\n"
+ "│ ├─ value ⇒ Literal: 1.0\n"
+ "│ └─ duration ⇒ Literal: 3.0\n"
+ "└─ Linear\n"
+ " ├─ start ⇒ Literal: 1.0\n"
+ " ├─ stop ⇒ Literal: 2.0\n"
+ " └─ duration ⇒ Literal: 2.0"
)


def test_wvfn_rec():
wf = Linear(start=1.0, stop=2.0, duration=3.0)
Expand All @@ -164,6 +235,21 @@ def test_wvfn_rec():
assert re.eval_decimal(Decimal("0")) == Decimal("1.0")
assert re.duration == cast(3.0)

mystdout = StringIO()
p = PP(mystdout)

re._repr_pretty_(p, 2)

assert (
mystdout.getvalue()
== "Record\n"
+ "├─ Waveform ⇒ Linear\n"
+ "│ ├─ start ⇒ Literal: 1.0\n"
+ "│ ├─ stop ⇒ Literal: 2.0\n"
+ "│ └─ duration ⇒ Literal: 3.0\n"
+ "└─ Variable ⇒ Variable: tst"
)


def test_wvfn_poly():
wf = Poly(checkpoints=[cast(1), cast(2), cast(3)], duration=10)
Expand Down Expand Up @@ -264,6 +350,21 @@ def test_wvfn_slice():
assert wf.eval_decimal(Decimal("0.2")) == 2.0
assert wf.children() == [wv, iv]

mystdout = StringIO()
p = PP(mystdout)
wf._repr_pretty_(p, 2)

assert (
mystdout.getvalue()
== "Slice\n"
+ "├─ Constant\n"
+ "│ ├─ value ⇒ Literal: 2.0\n"
+ "│ └─ duration ⇒ Literal: 3.0\n"
+ "└─ Interval\n"
+ " ├─ start ⇒ Literal: 0\n"
+ " └─ stop ⇒ Literal: 0.3"
)

iv_err1 = Interval(None, None)
wf2 = Slice(wv, iv_err1)
with pytest.raises(ValueError):
Expand All @@ -290,6 +391,19 @@ def test_wvfm_align():
assert wf3.print_node() == "AlignedWaveform"
assert wf3.children() == {"Waveform": wv, "Alignment": "Right", "Value": "Left"}

mystdout = StringIO()
p = PP(mystdout)
wf3._repr_pretty_(p, 2)

assert (
mystdout.getvalue()
== "AlignedWaveform\n"
+ "├─ Waveform ⇒ Constant\n"
+ "│ ├─ value ⇒ Literal: 2.0\n"
+ "│ └─ duration ⇒ Literal: 3.0\n"
+ "├─ Alignment ⇒ Right\n└─ Value ⇒ Left\n"
)


def test_wvfm_sample():
def my_cos(time):
Expand All @@ -305,13 +419,29 @@ def my_cos(time):
assert wf.print_node() == "Sample constant"
assert wf.children() == {"Waveform": wv, "sample_step": dt}
assert wf.eval_decimal(Decimal(0.05)) == my_cos(0)
assert float(wf.eval_decimal(Decimal(0))) == my_cos(0)

wf2 = Sample(wv, Interpolation.Linear, dt)

assert wf2.print_node() == "Sample linear"
assert wf2.children() == {"Waveform": wv, "sample_step": dt}
slope = (my_cos(0.1) - my_cos(0)) / 0.1
assert float(wf2.eval_decimal(Decimal(0.05))) == float(my_cos(0) + slope * 0.05)
assert float(wf2.eval_decimal(Decimal(3))) == 0
assert float(wf2.eval_decimal(Decimal(0))) == my_cos(0)

mystdout = StringIO()
p = PP(mystdout)

wf2._repr_pretty_(p, 2)

assert (
mystdout.getvalue()
== "Sample linear\n"
+ "├─ Waveform ⇒ PythonFn: my_cos\n"
+ "│ └─ duration ⇒ Literal: 1.0\n"
+ "└─ sample_step ⇒ Literal: 0.1"
)


"""
Expand Down
Loading