Skip to content

Commit

Permalink
atdcpp: add enum storage
Browse files Browse the repository at this point in the history
  • Loading branch information
elrandar committed Apr 8, 2024
1 parent d6c0d70 commit 3cf12ec
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 4 deletions.
122 changes: 118 additions & 4 deletions atdcpp/src/lib/Codegen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1167,15 +1167,15 @@ let case_class env type_name (loc, orig_name, unique_name, an, opt_e) case_clas
let read_cases0 env loc name cases0 =
let read_cases0 env loc name cases0 sum_repr =
let ifs =
cases0
|> List.map (fun (loc, orig_name, unique_name, an, opt_e) ->
let json_name = Atd.Json.get_json_cons orig_name an in
Inline [
Line (sprintf "if (std::string_view(x.GetString()) == \"%s\") " (single_esc json_name));
Block [
Line (sprintf "return Types::%s();" (trans env orig_name))
Line (sprintf "return Types::%s%s;" (trans env orig_name) (match sum_repr with | Cpp_annot.Variant -> "()" | Cpp_annot.Enum -> ""))
];
]
)
Expand Down Expand Up @@ -1223,7 +1223,7 @@ let sum_container env loc name cases codegen_type =
if cases0 <> [] then
[
Line "if (x.IsString()) {";
Block (read_cases0 env loc name cases0);
Block (read_cases0 env loc name cases0 Cpp_annot.Variant);
Line "}";
]
else
Expand Down Expand Up @@ -1345,13 +1345,127 @@ let sum env loc name cases codegen_type =
[Line(sprintf "typedef %s %s;" (sprintf "std::variant<%s>" type_list) (struct_name env name))]
| _ -> []
let enum_container env loc name cases codegen_type =
let cpp_struct_name = struct_name env name in
let cases0, cases1 =
List.partition (fun (loc, orig_name, unique_name, an, opt_e) ->
opt_e = None
) cases
in
let cases0_block =
if cases0 <> [] then
[
Line "if (x.IsString()) {";
Block (read_cases0 env loc name cases0 Cpp_annot.Enum);
Line "}";
]
else
[]
in
let cases1_block =
if cases1 <> [] then
error_at loc "enums with parameters are not supported"
else
[]
in
match codegen_type with
| Declaration ->
[
Line (sprintf "typedefs::%s from_json(const rapidjson::Value &x);" (cpp_struct_name));
Inline (from_json_string_declaration (sprintf "typedefs::%s" cpp_struct_name) false);
Line (sprintf "void to_json(const typedefs::%s &x, rapidjson::Writer<rapidjson::StringBuffer> &writer);" (cpp_struct_name));
Line (sprintf "std::string to_json_string(const typedefs::%s &x);" (cpp_struct_name));
]
| Definition ->
[
Line (sprintf "namespace %s {" (cpp_struct_name));
Block [
Line (sprintf "typedefs::%s from_json(const rapidjson::Value &x) {" (cpp_struct_name));
Block [
Inline cases0_block;
Inline cases1_block;
Line (sprintf "throw _atd_bad_json(\"%s\", x);"
(single_esc (struct_name env name)))
];
Line "}";
];
Block [Inline (from_json_string_definition (sprintf "typedefs::%s" cpp_struct_name) (None))];
Block [
Line (sprintf "void to_json(const typedefs::%s &x, rapidjson::Writer<rapidjson::StringBuffer> &writer) {" (cpp_struct_name));
Block [
Line "switch (x) {";
Block (
List.map (fun (loc, orig_name, unique_name, an, opt_e) ->
Line (sprintf "case Types::%s: _atd_write_string(\"%s\", writer); break;" (trans env orig_name) (trans env orig_name))
) cases
);
Line ("}");
];
Line ("}");
];
Block [Line (sprintf "std::string to_json_string(const typedefs::%s &x) {" (cpp_struct_name));
Block [
Line ("rapidjson::StringBuffer buffer;");
Line ("rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);");
Line ("to_json(x, writer);");
Line "return buffer.GetString();"
];
Line "}";];
Line ("}");
]
| _ -> []
let enum env loc name cases codegen_type =
let cases =
List.map (fun (x : variant) ->
match x with
| Variant (loc, (orig_name, an), opt_e) ->
let unique_name = create_struct_name env orig_name in
(loc, orig_name, unique_name, an, opt_e)
| Inherit _ -> assert false
) cases
in
(* let case_classes =
List.map (fun x -> Inline (case_class env name x codegen_type)) cases
in *)
let container_class = enum_container env loc name cases codegen_type in
match codegen_type with
| Declaration ->
[
Line (sprintf "namespace %s {" (struct_name env name));
Block [
Line (sprintf "typedef typedefs::%s Types;" (struct_name env name));
Line ("");
Inline container_class;
];
Line (sprintf "}");
]
| Definition ->
[
Inline container_class;
] |> double_spaced
| Struct_typedef ->
let type_list =
List.map (fun (loc, orig_name, unique_name, an, opt_e) ->
Line (sprintf "%s," (trans env orig_name))) cases in
[
Line (sprintf "enum class %s {" (struct_name env name));
Block type_list;
Line (sprintf "};");
]
| _ -> []
let type_def env ((loc, (name, param, an), e) : A.type_def) codegen_type : B.t =
if param <> [] then
not_implemented loc "parametrized type";
let unwrap e =
match e with
| Sum (loc, cases, an) ->
sum env loc name cases codegen_type
(match (Cpp_annot.get_cpp_sumtype_repr an) with
| Variant -> sum env loc name cases codegen_type
| Enum -> enum env loc name cases codegen_type)
| Record (loc, fields, an) ->
record env loc name fields an codegen_type
| Tuple _
Expand Down
16 changes: 16 additions & 0 deletions atdcpp/src/lib/Cpp_annot.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ type assoc_repr =
| List
| Dict

type sumtype_repr =
| Variant
| Enum

type atd_cpp_wrap = {
cpp_wrap_t : string;
cpp_wrap : string;
Expand All @@ -22,6 +26,18 @@ let get_cpp_default an : string option =
~field:"default"
an

let get_cpp_sumtype_repr an : sumtype_repr =
Atd.Annot.get_field
~parse:(function
| "variant" -> Some Variant
| "enum" -> Some Enum
| _ -> None
)
~default:Variant
~sections:["cpp"]
~field:"repr"
an

let get_cpp_assoc_repr an : assoc_repr =
Atd.Annot.get_field
~parse:(function
Expand Down
15 changes: 15 additions & 0 deletions atdcpp/src/lib/Cpp_annot.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ type assoc_repr =
| List
| Dict

(** Whether a sum type must be represented in cpp as a variant or as an enum.
This is independent of the JSON representation.
*)
type sumtype_repr =
| Variant
| Enum

(** Inspection of annotations placed on sum types such as
[type foo = A | B | C <cpp repr="enum">].
Permissible values for the [repr] field are ["enum"] and ["variant"].
The default is ["variant"].
*)
val get_cpp_sumtype_repr : Atd.Annot.t -> sumtype_repr


(** Inspect annotations placed on lists of pairs such as
[(string * foo) list <cpp repr="dict">].
Permissible values for the [repr] field are ["dict"] and ["list"].
Expand Down
8 changes: 8 additions & 0 deletions atdcpp/test/atd-input/everything.atd
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<cpp include="<stdint.h>">
(* <cpp namespace="types::atd"> *)

type kind = [
| Root (* class name conflict *)
Expand All @@ -12,6 +13,12 @@ type frozen = [
| B of int
]

type enum_sumtype = [
| A
| B
| C
] <cpp repr="enum">

type ('a, 'b) parametrized_record = {
field_a: 'a;
~field_b: 'b list;
Expand Down Expand Up @@ -46,6 +53,7 @@ type root = {
wrapped: st wrap <cpp t="uint16_t" wrap="[](typedefs::St st){return st - 1;}" unwrap="[](uint16_t e){return e + 1;}">;
aaa: alias_of_alias_of_alias;
item: string wrap <cpp t="int" wrap="std::stoi" unwrap="std::to_string">;
ee : enum_sumtype;
}


Expand Down
42 changes: 42 additions & 0 deletions atdcpp/test/cpp-expected/everything_atd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,43 @@ namespace Kind {
}


namespace EnumSumtype {
typedefs::EnumSumtype from_json(const rapidjson::Value &x) {
if (x.IsString()) {
if (std::string_view(x.GetString()) == "A")
return Types::A;
if (std::string_view(x.GetString()) == "B")
return Types::B;
if (std::string_view(x.GetString()) == "C")
return Types::C;
throw _atd_bad_json("EnumSumtype", x);
}
throw _atd_bad_json("EnumSumtype", x);
}
typedefs::EnumSumtype from_json_string(const std::string &s) {
rapidjson::Document doc;
doc.Parse(s.c_str());
if (doc.HasParseError()) {
throw AtdException("Failed to parse JSON");
}
return from_json(doc);
}
void to_json(const typedefs::EnumSumtype &x, rapidjson::Writer<rapidjson::StringBuffer> &writer) {
switch (x) {
case Types::A: _atd_write_string("A", writer); break;
case Types::B: _atd_write_string("B", writer); break;
case Types::C: _atd_write_string("C", writer); break;
}
}
std::string to_json_string(const typedefs::EnumSumtype &x) {
rapidjson::StringBuffer buffer;
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
to_json(x, writer);
return buffer.GetString();
}
}


namespace Alias3 {
typedefs::Alias3 from_json(const rapidjson::Value &doc) {
return _atd_read_wrap([](const auto& v){return _atd_read_int(v);}, [](const auto &e){return static_cast<uint32_t>(e);},doc);
Expand Down Expand Up @@ -930,6 +967,9 @@ Root Root::from_json(const rapidjson::Value & doc) {
if (doc.HasMember("item"))
record.item = _atd_read_wrap([](const auto& v){return _atd_read_string(v);}, [](const auto &e){return std::stoi(e);},doc["item"]);
else record.item = _atd_missing_json_field<decltype(record.item)>("Root", "item");
if (doc.HasMember("ee"))
record.ee = EnumSumtype::from_json(doc["ee"]);
else record.ee = _atd_missing_json_field<decltype(record.ee)>("Root", "ee");
return record;
}
Root Root::from_json_string(const std::string &s) {
Expand Down Expand Up @@ -1010,6 +1050,8 @@ void Root::to_json(const Root &t, rapidjson::Writer<rapidjson::StringBuffer> &wr
AliasOfAliasOfAlias::to_json(t.aaa, writer);
writer.Key("item");
_atd_write_wrap([](const auto &v, auto &w){_atd_write_string(v, w);}, [](const auto &e){return std::to_string(e);}, t.item, writer);
writer.Key("ee");
EnumSumtype::to_json(t.ee, writer);
writer.EndObject();
}
std::string Root::to_json_string(const Root &t) {
Expand Down
16 changes: 16 additions & 0 deletions atdcpp/test/cpp-expected/everything_atd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ namespace typedefs {
typedef RecursiveClass RecursiveClass;
typedef ThreeLevelNestedListRecord ThreeLevelNestedListRecord;
typedef StructWithRecursiveVariant StructWithRecursiveVariant;
enum class EnumSumtype {
A,
B,
C,
};
typedef IntFloatParametrizedRecord IntFloatParametrizedRecord;
typedef Root Root;
typedef RequireField RequireField;
Expand Down Expand Up @@ -195,6 +200,16 @@ namespace Kind {
}


namespace EnumSumtype {
typedef typedefs::EnumSumtype Types;

typedefs::EnumSumtype from_json(const rapidjson::Value &x);
typedefs::EnumSumtype from_json_string(const std::string &s);
void to_json(const typedefs::EnumSumtype &x, rapidjson::Writer<rapidjson::StringBuffer> &writer);
std::string to_json_string(const typedefs::EnumSumtype &x);
}


namespace Alias3 {
typedefs::Alias3 from_json(const rapidjson::Value &doc);
typedefs::Alias3 from_json_string(const std::string &s);
Expand Down Expand Up @@ -273,6 +288,7 @@ struct Root {
uint16_t wrapped;
typedefs::AliasOfAliasOfAlias aaa;
int item;
typedefs::EnumSumtype ee;

static Root from_json(const rapidjson::Value & doc);
static Root from_json_string(const std::string &s);
Expand Down
1 change: 1 addition & 0 deletions atdcpp/test/cpp-tests/test_atdd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ int main() {
root.wrapped = 1;
root.aaa = -90;
root.item = 45;
root.ee = EnumSumtype::Types::B;

std::string json = root.to_json_string();
Root rootFromJson = Root::from_json_string(json);
Expand Down

0 comments on commit 3cf12ec

Please sign in to comment.