Skip to content

Commit

Permalink
Fixed perfect derives of implicit/explicit tagging of fields
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Nov 17, 2024
1 parent cbe549b commit f90f20d
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 13 deletions.
68 changes: 55 additions & 13 deletions asn1_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub fn derive_asn1_read(input: proc_macro::TokenStream) -> proc_macro::TokenStre
all_field_types(&input.data),
syn::parse_quote!(asn1::Asn1Readable<#lifetime_name>),
syn::parse_quote!(asn1::Asn1DefinedByReadable<#lifetime_name, asn1::ObjectIdentifier>),
false,
);
let (impl_generics, _, where_clause) = generics.split_for_impl();

Expand Down Expand Up @@ -65,6 +66,7 @@ pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStr
all_field_types(&input.data),
syn::parse_quote!(asn1::Asn1Writable),
syn::parse_quote!(asn1::Asn1DefinedByWritable<asn1::ObjectIdentifier>),
true,
);
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

Expand Down Expand Up @@ -267,7 +269,7 @@ fn add_lifetime_if_none(generics: &mut syn::Generics) -> syn::Lifetime {
generics.lifetimes().next().unwrap().lifetime.clone()
}

fn all_field_types(data: &syn::Data) -> Vec<(syn::Type, bool)> {
fn all_field_types(data: &syn::Data) -> Vec<(syn::Type, OpType, bool)> {
let mut field_types = vec![];
match data {
syn::Data::Struct(v) => {
Expand All @@ -283,7 +285,7 @@ fn all_field_types(data: &syn::Data) -> Vec<(syn::Type, bool)> {
field_types
}

fn add_field_types(field_types: &mut Vec<(syn::Type, bool)>, fields: &syn::Fields) {
fn add_field_types(field_types: &mut Vec<(syn::Type, OpType, bool)>, fields: &syn::Fields) {
match fields {
syn::Fields::Named(v) => {
for f in &v.named {
Expand All @@ -299,16 +301,17 @@ fn add_field_types(field_types: &mut Vec<(syn::Type, bool)>, fields: &syn::Field
}
}

fn add_field_type(field_types: &mut Vec<(syn::Type, bool)>, f: &syn::Field) {
let (op_type, _) = extract_field_properties(&f.attrs);
field_types.push((f.ty.clone(), matches!(op_type, OpType::DefinedBy(_))));
fn add_field_type(field_types: &mut Vec<(syn::Type, OpType, bool)>, f: &syn::Field) {
let (op_type, default) = extract_field_properties(&f.attrs);
field_types.push((f.ty.clone(), op_type, default.is_some()));
}

fn add_bounds(
generics: &mut syn::Generics,
field_types: Vec<(syn::Type, bool)>,
field_types: Vec<(syn::Type, OpType, bool)>,
bound: syn::TypeParamBound,
defined_by_bound: syn::TypeParamBound,
add_ref: bool,
) {
let where_clause = if field_types.is_empty() {
return;
Expand All @@ -321,20 +324,59 @@ fn add_bounds(
})
};

for (f, is_defined_by) in field_types {
for (f, op_type, has_default) in field_types {
let (bounded_ty, required_bound) = match (op_type, add_ref) {
(OpType::Regular, _) => (f, bound.clone()),
(OpType::DefinedBy(_), _) => (f, defined_by_bound.clone()),

(OpType::Implicit(OpTypeArgs { value, required }), false) => {
let ty = if required || has_default {
syn::parse_quote!(asn1::Implicit::<#f, #value>)
} else {
syn::parse_quote!(asn1::Implicit::<<#f as asn1::OptionExt>::T, #value>)
};

(ty, bound.clone())
}
(OpType::Implicit(OpTypeArgs { value, required }), true) => {
let ty = if required || has_default {
syn::parse_quote!(for<'asn1_internal> asn1::Implicit::<&'asn1_internal #f, #value>)
} else {
syn::parse_quote!(for<'asn1_internal> asn1::Implicit::<&'asn1_internal <#f as asn1::OptionExt>::T, #value>)
};

(ty, bound.clone())
}

(OpType::Explicit(OpTypeArgs { value, required }), false) => {
let ty = if required || has_default {
syn::parse_quote!(asn1::Explicit::<#f, #value>)
} else {
syn::parse_quote!(asn1::Explicit::<<#f as asn1::OptionExt>::T, #value>)
};

(ty, bound.clone())
}
(OpType::Explicit(OpTypeArgs { value, required }), true) => {
let ty = if required || has_default {
syn::parse_quote!(for<'asn1_internal> asn1::Explicit::<&'asn1_internal #f, #value>)
} else {
syn::parse_quote!(for<'asn1_internal> asn1::Explicit::<&'asn1_internal <#f as asn1::OptionExt>::T, #value>)
};

(ty, bound.clone())
}
};

where_clause
.predicates
.push(syn::WherePredicate::Type(syn::PredicateType {
lifetimes: None,
bounded_ty: f,
bounded_ty,
colon_token: Default::default(),
bounds: {
let mut p = syn::punctuated::Punctuated::new();
if is_defined_by {
p.push(defined_by_bound.clone());
} else {
p.push(bound.clone());
}
p.push(required_bound);
p
},
}))
Expand Down
11 changes: 11 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,14 @@ pub fn write_defined_by<T: Asn1Writable, U: Asn1DefinedByWritable<T>>(
pub fn writable_defined_by_item<T: Asn1Writable, U: Asn1DefinedByWritable<T>>(v: &U) -> &T {
v.item()
}

/// Utility for use in `asn1_derive`. Not considered part of the public API.
#[doc(hidden)]
pub trait OptionExt {
type T;
}

#[doc(hidden)]
impl<T> OptionExt for Option<T> {
type T = T;
}
32 changes: 32 additions & 0 deletions tests/derive_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -758,4 +758,36 @@ fn test_perfect_derive() {
}

assert_roundtrips::<S<Op>>(&[(Ok(S { value: 12 }), b"\x30\x03\x02\x01\x0c")]);

#[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)]
struct TaggedRequiredFields<T: X> {
#[implicit(1, required)]
a: T::Type,
#[explicit(2, required)]
b: T::Type,
}

assert_roundtrips::<TaggedRequiredFields<Op>>(&[(
Ok(TaggedRequiredFields { a: 1, b: 3 }),
b"\x30\x08\x81\x01\x01\xa2\x03\x02\x01\x03",
)]);

#[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)]
struct TaggedOptionalFields<T: X> {
#[implicit(1)]
a: Option<T::Type>,
#[explicit(2)]
b: Option<T::Type>,
}

assert_roundtrips::<TaggedOptionalFields<Op>>(&[
(
Ok(TaggedOptionalFields {
a: Some(1),
b: Some(3),
}),
b"\x30\x08\x81\x01\x01\xa2\x03\x02\x01\x03",
),
(Ok(TaggedOptionalFields { a: None, b: None }), b"\x30\x00"),
]);
}

0 comments on commit f90f20d

Please sign in to comment.