Skip to content

Commit eadc043

Browse files
authored
@newtype on array/map wrappers (#237)
* @newtype on array/map wrappers Previously this only worked on primitive wrappers. Now you can do e.g. `foo = [* uint] ; @newtype` which was ignored before. * cargo fmt * tests added
1 parent e75a0ad commit eadc043

File tree

7 files changed

+160
-21
lines changed

7 files changed

+160
-21
lines changed

src/intermediate.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,21 @@ impl AliasInfo {
9090
}
9191
}
9292

93+
#[derive(Debug, Clone)]
94+
pub struct PlainGroupInfo<'a> {
95+
group: Option<cddl::ast::Group<'a>>,
96+
rule_metadata: RuleMetadata,
97+
}
98+
99+
impl<'a> PlainGroupInfo<'a> {
100+
pub fn new(group: Option<cddl::ast::Group<'a>>, rule_metadata: RuleMetadata) -> Self {
101+
Self {
102+
group,
103+
rule_metadata,
104+
}
105+
}
106+
}
107+
93108
#[derive(Debug)]
94109
pub struct IntermediateTypes<'a> {
95110
// Storing the cddl::Group is the easiest way to go here even after the parse/codegen split.
@@ -99,7 +114,7 @@ pub struct IntermediateTypes<'a> {
99114
// delayed until the point where it is referenced via self.set_rep_if_plain_group(rep)
100115
// Some(group) = directly defined in .cddl (must call set_plain_group_representatio() later)
101116
// None = indirectly generated due to a group choice (no reason to call set_rep_if_plain_group() later but it won't crash)
102-
plain_groups: BTreeMap<RustIdent, Option<cddl::ast::Group<'a>>>,
117+
plain_groups: BTreeMap<RustIdent, PlainGroupInfo<'a>>,
103118
type_aliases: BTreeMap<AliasIdent, AliasInfo>,
104119
rust_structs: BTreeMap<RustIdent, RustStruct>,
105120
prelude_to_emit: BTreeSet<String>,
@@ -642,8 +657,8 @@ impl<'a> IntermediateTypes<'a> {
642657
}
643658

644659
// see self.plain_groups comments
645-
pub fn mark_plain_group(&mut self, ident: RustIdent, group: Option<cddl::ast::Group<'a>>) {
646-
self.plain_groups.insert(ident, group);
660+
pub fn mark_plain_group(&mut self, ident: RustIdent, group_info: PlainGroupInfo<'a>) {
661+
self.plain_groups.insert(ident, group_info);
647662
}
648663

649664
// see self.plain_groups comments
@@ -656,7 +671,8 @@ impl<'a> IntermediateTypes<'a> {
656671
) {
657672
if let Some(plain_group) = self.plain_groups.get(ident) {
658673
// the clone is to get around the borrow checker
659-
if let Some(group) = plain_group.as_ref().cloned() {
674+
let plain_group = plain_group.clone();
675+
if let Some(group) = plain_group.group.as_ref() {
660676
// we are defined via .cddl and thus need to register a concrete
661677
// representation of the plain group
662678
if let Some(rust_struct) = self.rust_structs.get(ident) {
@@ -673,11 +689,12 @@ impl<'a> IntermediateTypes<'a> {
673689
crate::parsing::parse_group(
674690
self,
675691
parent_visitor,
676-
&group,
692+
group,
677693
ident,
678694
rep,
679695
None,
680696
None,
697+
&plain_group.rule_metadata,
681698
cli,
682699
);
683700
}

src/main.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ pub(crate) mod utils;
99

1010
use clap::Parser;
1111
use cli::Cli;
12+
use comment_ast::RuleMetadata;
1213
use generation::GenerationScope;
13-
use intermediate::{CDDLIdent, IntermediateTypes, RustIdent};
14+
use intermediate::{CDDLIdent, IntermediateTypes, PlainGroupInfo, RustIdent};
1415
use once_cell::sync::Lazy;
1516
use parsing::{parse_rule, rule_ident, rule_is_scope_marker};
1617

@@ -128,10 +129,16 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
128129
if let cddl::ast::Rule::Group { rule, .. } = cddl_rule {
129130
// Freely defined group - no need to generate anything outside of group module
130131
match &rule.entry {
131-
cddl::ast::GroupEntry::InlineGroup { group, .. } => {
132+
cddl::ast::GroupEntry::InlineGroup {
133+
group,
134+
comments_after_group,
135+
..
136+
} => {
137+
assert_eq!(group.group_choices.len(), 1);
138+
let rule_metadata = RuleMetadata::from(comments_after_group.as_ref());
132139
types.mark_plain_group(
133140
RustIdent::new(CDDLIdent::new(rule.name.to_string())),
134-
Some(group.clone()),
141+
PlainGroupInfo::new(Some(group.clone()), rule_metadata),
135142
);
136143
}
137144
x => panic!("Group rule with non-inline group? {:?}", x),

src/parsing.rs

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ use std::collections::BTreeMap;
66
use crate::comment_ast::{merge_metadata, metadata_from_comments, RuleMetadata};
77
use crate::intermediate::{
88
AliasInfo, CBOREncodingOperation, CDDLIdent, ConceptualRustType, EnumVariant, FixedValue,
9-
GenericDef, GenericInstance, IntermediateTypes, ModuleScope, Primitive, Representation,
10-
RustField, RustIdent, RustRecord, RustStruct, RustStructType, RustType, VariantIdent,
9+
GenericDef, GenericInstance, IntermediateTypes, ModuleScope, PlainGroupInfo, Primitive,
10+
Representation, RustField, RustIdent, RustRecord, RustStruct, RustStructType, RustType,
11+
VariantIdent,
1112
};
1213
use crate::utils::{
1314
append_number_if_duplicate, convert_to_camel_case, convert_to_snake_case,
@@ -640,6 +641,7 @@ fn parse_type(
640641
Representation::Map,
641642
outer_tag,
642643
generic_params,
644+
&rule_metadata,
643645
cli,
644646
);
645647
}
@@ -654,6 +656,7 @@ fn parse_type(
654656
Representation::Array,
655657
outer_tag,
656658
generic_params,
659+
&rule_metadata,
657660
cli,
658661
);
659662
}
@@ -1222,6 +1225,7 @@ fn rust_type_from_type2(
12221225
Representation::Array,
12231226
None,
12241227
None,
1228+
&rule_metadata,
12251229
cli,
12261230
);
12271231
// we aren't returning an array, but rather a struct where the fields are ordered
@@ -1446,27 +1450,59 @@ fn parse_group_choice(
14461450
rep: Representation,
14471451
tag: Option<usize>,
14481452
generic_params: Option<Vec<RustIdent>>,
1453+
parent_rule_metadata: Option<&RuleMetadata>,
14491454
cli: &Cli,
14501455
) {
14511456
let rule_metadata = RuleMetadata::from(
14521457
get_comment_after(parent_visitor, &CDDLType::from(group_choice), None).as_ref(),
14531458
);
1459+
let rule_metadata = if let Some(parent_rule_metadata) = parent_rule_metadata {
1460+
merge_metadata(&rule_metadata, parent_rule_metadata)
1461+
} else {
1462+
rule_metadata
1463+
};
14541464
let rust_struct = match parse_group_type(types, parent_visitor, group_choice, rep, cli) {
14551465
GroupParsingType::HomogenousArray(element_type) => {
1456-
// Array - homogeneous element type with proper occurence operator
1457-
RustStruct::new_array(name.clone(), tag, Some(&rule_metadata), element_type)
1466+
if rule_metadata.is_newtype {
1467+
// generate newtype over array
1468+
RustStruct::new_wrapper(
1469+
name.clone(),
1470+
tag,
1471+
Some(&rule_metadata),
1472+
ConceptualRustType::Array(Box::new(element_type)).into(),
1473+
None,
1474+
)
1475+
} else {
1476+
// Array - homogeneous element type with proper occurence operator
1477+
RustStruct::new_array(name.clone(), tag, Some(&rule_metadata), element_type)
1478+
}
14581479
}
14591480
GroupParsingType::HomogenousMap(key_type, value_type) => {
1460-
// Table map - homogeneous key/value types
1461-
RustStruct::new_table(
1462-
name.clone(),
1463-
tag,
1464-
Some(&rule_metadata),
1465-
key_type,
1466-
value_type,
1467-
)
1481+
if rule_metadata.is_newtype {
1482+
// generate newtype over map
1483+
RustStruct::new_wrapper(
1484+
name.clone(),
1485+
tag,
1486+
Some(&rule_metadata),
1487+
ConceptualRustType::Map(Box::new(key_type), Box::new(value_type)).into(),
1488+
None,
1489+
)
1490+
} else {
1491+
// Table map - homogeneous key/value types
1492+
RustStruct::new_table(
1493+
name.clone(),
1494+
tag,
1495+
Some(&rule_metadata),
1496+
key_type,
1497+
value_type,
1498+
)
1499+
}
14681500
}
14691501
GroupParsingType::Heterogenous | GroupParsingType::WrappedBasicGroup(_) => {
1502+
assert!(
1503+
!rule_metadata.is_newtype,
1504+
"Can only use @newtype on primtives + heterogenious arrays/maps"
1505+
);
14701506
// Heterogenous map or array with defined key/value pairs in the cddl like a struct
14711507
let record =
14721508
parse_record_from_group_choice(types, rep, parent_visitor, group_choice, cli);
@@ -1489,6 +1525,7 @@ pub fn parse_group(
14891525
rep: Representation,
14901526
tag: Option<usize>,
14911527
generic_params: Option<Vec<RustIdent>>,
1528+
parent_rule_metadata: &RuleMetadata,
14921529
cli: &Cli,
14931530
) {
14941531
if group.group_choices.len() == 1 {
@@ -1501,12 +1538,14 @@ pub fn parse_group(
15011538
rep,
15021539
tag,
15031540
generic_params,
1541+
Some(parent_rule_metadata),
15041542
cli,
15051543
);
15061544
} else {
15071545
if generic_params.is_some() {
15081546
todo!("{}: generic group choices not supported", name);
15091547
}
1548+
assert!(!parent_rule_metadata.is_newtype);
15101549
// Generate Enum object that is not exposed to wasm, since wasm can't expose
15111550
// fully featured rust enums via wasm_bindgen
15121551

@@ -1570,7 +1609,10 @@ pub fn parse_group(
15701609
let ident_name = rule_metadata.name.unwrap_or_else(|| format!("{name}{i}"));
15711610
// General case, GroupN type identifiers and generate group choice since it's inlined here
15721611
let variant_name = RustIdent::new(CDDLIdent::new(ident_name));
1573-
types.mark_plain_group(variant_name.clone(), None);
1612+
types.mark_plain_group(
1613+
variant_name.clone(),
1614+
PlainGroupInfo::new(None, RuleMetadata::default()),
1615+
);
15741616
parse_group_choice(
15751617
types,
15761618
parent_visitor,
@@ -1579,6 +1621,7 @@ pub fn parse_group(
15791621
rep,
15801622
None,
15811623
generic_params.clone(),
1624+
None,
15821625
cli,
15831626
);
15841627
let name = VariantIdent::new_rust(variant_name.clone());

tests/core/input.cddl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ inline_wrapper = [{ * text => text }]
185185
top_level_array = [* uint]
186186
top_level_single_elem = [uint]
187187

188+
wrapper_table = { * uint => uint } ; @newtype
189+
wrapper_list = [ * uint ] ; @newtype
190+
188191
overlapping_inlined = [
189192
; @name one
190193
0 //

tests/core/tests.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,4 +531,35 @@ mod tests {
531531
].into_iter().flatten().clone().collect::<Vec<u8>>();
532532
assert_eq!(expected_bytes, struct_with_custom_bytes.to_cbor_bytes());
533533
}
534+
535+
#[test]
536+
fn wrapper_table() {
537+
use cbor_event::Sz;
538+
let bytes = vec![
539+
map_sz(3, Sz::Inline),
540+
cbor_int(5, Sz::Inline),
541+
cbor_int(4, Sz::Inline),
542+
cbor_int(3, Sz::Inline),
543+
cbor_int(2, Sz::Inline),
544+
cbor_int(1, Sz::Inline),
545+
cbor_int(0, Sz::Inline),
546+
].into_iter().flatten().clone().collect::<Vec<u8>>();
547+
let from_bytes = WrapperTable::from_cbor_bytes(&bytes).unwrap();
548+
deser_test(&from_bytes);
549+
}
550+
551+
#[test]
552+
fn wrapper_list() {
553+
use cbor_event::Sz;
554+
let bytes = vec![
555+
arr_sz(5, Sz::Inline),
556+
cbor_int(5, Sz::Inline),
557+
cbor_int(4, Sz::Inline),
558+
cbor_int(3, Sz::Inline),
559+
cbor_int(2, Sz::Inline),
560+
cbor_int(1, Sz::Inline),
561+
].into_iter().flatten().clone().collect::<Vec<u8>>();
562+
let from_bytes = WrapperList::from_cbor_bytes(&bytes).unwrap();
563+
deser_test(&from_bytes);
564+
}
534565
}

tests/preserve-encodings/input.cddl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,6 @@ struct_with_custom_serialization = [
211211
tagged1: #6.9(custom_bytes),
212212
tagged2: #6.9(uint), ; @custom_serialize write_tagged_uint_str @custom_deserialize read_tagged_uint_str
213213
]
214+
215+
wrapper_table = { * uint => uint } ; @newtype
216+
wrapper_list = [ * uint ] ; @newtype

tests/preserve-encodings/tests.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,4 +1264,39 @@ mod tests {
12641264
}
12651265
}
12661266
}
1267+
1268+
#[test]
1269+
fn wrapper_table() {
1270+
let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight];
1271+
for def_enc in &def_encodings {
1272+
let irregular_bytes = vec![
1273+
map_sz(3, *def_enc),
1274+
cbor_int(5, *def_enc),
1275+
cbor_int(4, *def_enc),
1276+
cbor_int(3, *def_enc),
1277+
cbor_int(2, *def_enc),
1278+
cbor_int(1, *def_enc),
1279+
cbor_int(0, *def_enc),
1280+
].into_iter().flatten().clone().collect::<Vec<u8>>();
1281+
let from_bytes = WrapperTable::from_cbor_bytes(&irregular_bytes).unwrap();
1282+
assert_eq!(from_bytes.to_cbor_bytes(), irregular_bytes);
1283+
}
1284+
}
1285+
1286+
#[test]
1287+
fn wrapper_list() {
1288+
let def_encodings = vec![Sz::Inline, Sz::One, Sz::Two, Sz::Four, Sz::Eight];
1289+
for def_enc in &def_encodings {
1290+
let irregular_bytes = vec![
1291+
arr_sz(5, *def_enc),
1292+
cbor_int(5, *def_enc),
1293+
cbor_int(4, *def_enc),
1294+
cbor_int(3, *def_enc),
1295+
cbor_int(2, *def_enc),
1296+
cbor_int(1, *def_enc),
1297+
].into_iter().flatten().clone().collect::<Vec<u8>>();
1298+
let from_bytes = WrapperList::from_cbor_bytes(&irregular_bytes).unwrap();
1299+
assert_eq!(from_bytes.to_cbor_bytes(), irregular_bytes);
1300+
}
1301+
}
12671302
}

0 commit comments

Comments
 (0)