diff --git a/asn1_derive/src/lib.rs b/asn1_derive/src/lib.rs index 47370d7..2222ff0 100644 --- a/asn1_derive/src/lib.rs +++ b/asn1_derive/src/lib.rs @@ -17,6 +17,7 @@ pub fn derive_asn1_read(input: proc_macro::TokenStream) -> proc_macro::TokenStre all_field_types(&input.data), syn::parse_quote!(asn1::Asn1Readable<#lifetime_name>), syn::parse_quote!(asn1::Asn1DefinedByReadable<#lifetime_name, asn1::ObjectIdentifier>), + false, ); let (impl_generics, _, where_clause) = generics.split_for_impl(); @@ -65,6 +66,7 @@ pub fn derive_asn1_write(input: proc_macro::TokenStream) -> proc_macro::TokenStr all_field_types(&input.data), syn::parse_quote!(asn1::Asn1Writable), syn::parse_quote!(asn1::Asn1DefinedByWritable), + true, ); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); @@ -267,7 +269,7 @@ fn add_lifetime_if_none(generics: &mut syn::Generics) -> syn::Lifetime { generics.lifetimes().next().unwrap().lifetime.clone() } -fn all_field_types(data: &syn::Data) -> Vec<(syn::Type, bool)> { +fn all_field_types(data: &syn::Data) -> Vec<(syn::Type, OpType, bool)> { let mut field_types = vec![]; match data { syn::Data::Struct(v) => { @@ -283,7 +285,7 @@ fn all_field_types(data: &syn::Data) -> Vec<(syn::Type, bool)> { field_types } -fn add_field_types(field_types: &mut Vec<(syn::Type, bool)>, fields: &syn::Fields) { +fn add_field_types(field_types: &mut Vec<(syn::Type, OpType, bool)>, fields: &syn::Fields) { match fields { syn::Fields::Named(v) => { for f in &v.named { @@ -299,16 +301,17 @@ fn add_field_types(field_types: &mut Vec<(syn::Type, bool)>, fields: &syn::Field } } -fn add_field_type(field_types: &mut Vec<(syn::Type, bool)>, f: &syn::Field) { - let (op_type, _) = extract_field_properties(&f.attrs); - field_types.push((f.ty.clone(), matches!(op_type, OpType::DefinedBy(_)))); +fn add_field_type(field_types: &mut Vec<(syn::Type, OpType, bool)>, f: &syn::Field) { + let (op_type, default) = extract_field_properties(&f.attrs); + field_types.push((f.ty.clone(), op_type, default.is_some())); } fn add_bounds( generics: &mut syn::Generics, - field_types: Vec<(syn::Type, bool)>, + field_types: Vec<(syn::Type, OpType, bool)>, bound: syn::TypeParamBound, defined_by_bound: syn::TypeParamBound, + add_ref: bool, ) { let where_clause = if field_types.is_empty() { return; @@ -321,20 +324,59 @@ fn add_bounds( }) }; - for (f, is_defined_by) in field_types { + for (f, op_type, has_default) in field_types { + let (bounded_ty, required_bound) = match (op_type, add_ref) { + (OpType::Regular, _) => (f, bound.clone()), + (OpType::DefinedBy(_), _) => (f, defined_by_bound.clone()), + + (OpType::Implicit(OpTypeArgs { value, required }), false) => { + let ty = if required || has_default { + syn::parse_quote!(asn1::Implicit::<#f, #value>) + } else { + syn::parse_quote!(asn1::Implicit::<<#f as asn1::OptionExt>::T, #value>) + }; + + (ty, bound.clone()) + } + (OpType::Implicit(OpTypeArgs { value, required }), true) => { + let ty = if required || has_default { + syn::parse_quote!(for<'asn1_internal> asn1::Implicit::<&'asn1_internal #f, #value>) + } else { + syn::parse_quote!(for<'asn1_internal> asn1::Implicit::<&'asn1_internal <#f as asn1::OptionExt>::T, #value>) + }; + + (ty, bound.clone()) + } + + (OpType::Explicit(OpTypeArgs { value, required }), false) => { + let ty = if required || has_default { + syn::parse_quote!(asn1::Explicit::<#f, #value>) + } else { + syn::parse_quote!(asn1::Explicit::<<#f as asn1::OptionExt>::T, #value>) + }; + + (ty, bound.clone()) + } + (OpType::Explicit(OpTypeArgs { value, required }), true) => { + let ty = if required || has_default { + syn::parse_quote!(for<'asn1_internal> asn1::Explicit::<&'asn1_internal #f, #value>) + } else { + syn::parse_quote!(for<'asn1_internal> asn1::Explicit::<&'asn1_internal <#f as asn1::OptionExt>::T, #value>) + }; + + (ty, bound.clone()) + } + }; + where_clause .predicates .push(syn::WherePredicate::Type(syn::PredicateType { lifetimes: None, - bounded_ty: f, + bounded_ty, colon_token: Default::default(), bounds: { let mut p = syn::punctuated::Punctuated::new(); - if is_defined_by { - p.push(defined_by_bound.clone()); - } else { - p.push(bound.clone()); - } + p.push(required_bound); p }, })) diff --git a/src/lib.rs b/src/lib.rs index d573109..a62bb3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -231,3 +231,14 @@ pub fn write_defined_by>( pub fn writable_defined_by_item>(v: &U) -> &T { v.item() } + +/// Utility for use in `asn1_derive`. Not considered part of the public API. +#[doc(hidden)] +pub trait OptionExt { + type T; +} + +#[doc(hidden)] +impl OptionExt for Option { + type T = T; +} diff --git a/tests/derive_test.rs b/tests/derive_test.rs index 8d58493..f5b00f4 100644 --- a/tests/derive_test.rs +++ b/tests/derive_test.rs @@ -758,4 +758,36 @@ fn test_perfect_derive() { } assert_roundtrips::>(&[(Ok(S { value: 12 }), b"\x30\x03\x02\x01\x0c")]); + + #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + struct TaggedRequiredFields { + #[implicit(1, required)] + a: T::Type, + #[explicit(2, required)] + b: T::Type, + } + + assert_roundtrips::>(&[( + Ok(TaggedRequiredFields { a: 1, b: 3 }), + b"\x30\x08\x81\x01\x01\xa2\x03\x02\x01\x03", + )]); + + #[derive(asn1::Asn1Read, asn1::Asn1Write, PartialEq, Debug, Eq)] + struct TaggedOptionalFields { + #[implicit(1)] + a: Option, + #[explicit(2)] + b: Option, + } + + assert_roundtrips::>(&[ + ( + Ok(TaggedOptionalFields { + a: Some(1), + b: Some(3), + }), + b"\x30\x08\x81\x01\x01\xa2\x03\x02\x01\x03", + ), + (Ok(TaggedOptionalFields { a: None, b: None }), b"\x30\x00"), + ]); }