Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions swift/codegen/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
load("@swift_codegen_deps//:requirements.bzl", "requirement")

py_binary(
name = "codegen",
srcs = glob(["*.py"]),
srcs = glob(
["*.py"],
exclude = ["trapgen.py"],
),
visibility = ["//swift/codegen/test:__pkg__"],
deps = ["//swift/codegen/lib"],
)
Expand All @@ -12,5 +17,8 @@ py_binary(
srcs = ["trapgen.py"],
data = ["//swift/codegen/templates:cpp"],
visibility = ["//swift:__subpackages__"],
deps = ["//swift/codegen/lib"],
deps = [
"//swift/codegen/lib",
requirement("toposort"),
],
)
5 changes: 4 additions & 1 deletion swift/codegen/lib/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ class Field:
type: str
first: bool = False

@property
def cpp_name(self):
if self.name in cpp_keywords:
return self.name + "_"
return self.name

def stream(self):
# using @property breaks pystache internals here
def get_streamer(self):
if self.type == "std::string":
return lambda x: f"trapQuoted({x})"
elif self.type == "bool":
Expand Down Expand Up @@ -65,6 +67,7 @@ def __post_init__(self):
self.bases = [TagBase(b) for b in self.bases]
self.bases[0].first = True

@property
def has_bases(self):
return bool(self.bases)

Expand Down
7 changes: 2 additions & 5 deletions swift/codegen/lib/dbscheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,10 @@ def get_union(match):


def iterload(file):
data = Re.comment.sub("", file.read())
with open(file) as file:
data = Re.comment.sub("", file.read())
for e in Re.entity.finditer(data):
if e["table"]:
yield get_table(e)
elif e["union"]:
yield get_union(e)


def load(file):
return list(iterload(file))
5 changes: 3 additions & 2 deletions swift/codegen/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pystache
pyyaml
inflection
pystache
pytest
pyyaml
toposort
2 changes: 1 addition & 1 deletion swift/codegen/templates/cpp_traps.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct {{name}}Trap {

inline std::ostream &operator<<(std::ostream &out, const {{name}}Trap &e) {
out << "{{table_name}}("{{#fields}}{{^first}} << ", "{{/first}}
<< {{#stream}}e.{{cpp_name}}{{/stream}}{{/fields}} << ")";
<< {{#get_streamer}}e.{{cpp_name}}{{/get_streamer}}{{/fields}} << ")";
return out;
}
{{/traps}}
Expand Down
1 change: 1 addition & 0 deletions swift/codegen/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ py_library(
deps = [
":utils",
"//swift/codegen",
"//swift/codegen:trapgen",
],
)
for src in glob(["test_*.py"])
Expand Down
60 changes: 60 additions & 0 deletions swift/codegen/test/test_cpp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import sys
from copy import deepcopy

import pytest

from swift.codegen.lib import cpp


@pytest.mark.parametrize("keyword", cpp.cpp_keywords)
def test_field_keyword_cpp_name(keyword):
f = cpp.Field(keyword, "int")
assert f.cpp_name == keyword + "_"


def test_field_cpp_name():
f = cpp.Field("foo", "int")
assert f.cpp_name == "foo"


@pytest.mark.parametrize("type,expected", [
("std::string", "trapQuoted(value)"),
("bool", '(value ? "true" : "false")'),
("something_else", "value"),
])
def test_field_get_streamer(type, expected):
f = cpp.Field("name", type)
assert f.get_streamer()("value") == expected


def test_trap_has_first_field_marked():
fields = [
cpp.Field("a", "x"),
cpp.Field("b", "y"),
cpp.Field("c", "z"),
]
expected = deepcopy(fields)
expected[0].first = True
t = cpp.Trap("table_name", "name", fields)
assert t.fields == expected


def test_tag_has_first_base_marked():
bases = ["a", "b", "c"]
expected = [cpp.TagBase("a", first=True), cpp.TagBase("b"), cpp.TagBase("c")]
t = cpp.Tag("name", bases, 0, "id")
assert t.bases == expected


@pytest.mark.parametrize("bases,expected", [
([], False),
(["a"], True),
(["a", "b"], True)
])
def test_tag_has_bases(bases, expected):
t = cpp.Tag("name", bases, 0, "id")
assert t.has_bases is expected


if __name__ == '__main__':
sys.exit(pytest.main())
102 changes: 102 additions & 0 deletions swift/codegen/test/test_dbscheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,107 @@ def test_union_has_first_case_marked():
assert [c.type for c in u.rhs] == rhs


# load tests
@pytest.fixture
def load(tmp_path):
file = tmp_path / "test.dbscheme"

def ret(yml):
write(file, yml)
return list(dbscheme.iterload(file))

return ret


def test_load_empty(load):
assert load("") == []


def test_load_one_empty_table(load):
assert load("""
test_foos();
""") == [
dbscheme.Table(name="test_foos", columns=[])
]


def test_load_table_with_keyset(load):
assert load("""
#keyset[x, y,z]
test_foos();
""") == [
dbscheme.Table(name="test_foos", columns=[], keyset=dbscheme.KeySet(["x", "y", "z"]))
]


expected_columns = [
("int foo: int ref", dbscheme.Column(schema_name="foo", type="int", binding=False)),
(" int bar : int ref", dbscheme.Column(schema_name="bar", type="int", binding=False)),
("str baz_: str ref", dbscheme.Column(schema_name="baz", type="str", binding=False)),
("int x: @foo ref", dbscheme.Column(schema_name="x", type="@foo", binding=False)),
("int y: @foo", dbscheme.Column(schema_name="y", type="@foo", binding=True)),
("unique int z: @foo", dbscheme.Column(schema_name="z", type="@foo", binding=True)),
]


@pytest.mark.parametrize("column,expected", expected_columns)
def test_load_table_with_column(load, column, expected):
assert load(f"""
foos(
{column}
);
""") == [
dbscheme.Table(name="foos", columns=[deepcopy(expected)])
]


def test_load_table_with_multiple_columns(load):
columns = ",\n".join(c for c, _ in expected_columns)
expected = [deepcopy(e) for _, e in expected_columns]
assert load(f"""
foos(
{columns}
);
""") == [
dbscheme.Table(name="foos", columns=expected)
]


def test_load_multiple_table_with_columns(load):
tables = [f"table{i}({col});" for i, (col, _) in enumerate(expected_columns)]
expected = [dbscheme.Table(name=f"table{i}", columns=[deepcopy(e)]) for i, (_, e) in enumerate(expected_columns)]
assert load("\n".join(tables)) == expected


def test_union(load):
assert load("@foo = @bar | @baz | @bla;") == [
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]


def test_table_and_union(load):
assert load("""
foos();

@foo = @bar | @baz | @bla;""") == [
dbscheme.Table(name="foos", columns=[]),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]


def test_comments_ignored(load):
assert load("""
// fake_table();
foos(/* x */unique /*y*/int/*
z
*/ id/* */: /* * */ @bar/*,
int ignored: int ref*/);

@foo = @bar | @baz | @bla; // | @xxx""") == [
dbscheme.Table(name="foos", columns=[dbscheme.Column(schema_name="id", type="@bar", binding=True)]),
dbscheme.Union(lhs="@foo", rhs=["@bar", "@baz", "@bla"]),
]


if __name__ == '__main__':
sys.exit(pytest.main())
Loading