99use proc_macro2:: TokenStream ;
1010use quote:: quote;
1111use syn:: {
12- parse_quote, spanned:: Spanned as _, DataEnum , DeriveInput , Error , Fields , Generics , Ident , Path ,
12+ parse_quote, spanned:: Spanned as _, DataEnum , DeriveInput , Error , Fields , Generics , Ident ,
13+ Index , Path ,
1314} ;
1415
1516use crate :: {
16- derive_try_from_bytes_inner, repr:: EnumRepr , DataExt , FieldBounds , ImplBlockBuilder , Trait ,
17+ derive_has_field_struct_union, derive_try_from_bytes_inner, repr:: EnumRepr , DataExt ,
18+ FieldBounds , ImplBlockBuilder , Trait ,
1719} ;
1820
1921/// Generates a tag enum for the given enum. This generates an enum with the
@@ -162,7 +164,18 @@ fn generate_variant_structs(
162164 }
163165}
164166
165- fn generate_variants_union ( generics : & Generics , data : & DataEnum ) -> TokenStream {
167+ fn variants_union_field_ident ( ident : & Ident ) -> Ident {
168+ let variant_ident_str = crate :: ext:: to_ident_str ( ident) ;
169+ // Field names are prefixed with `__field_` to prevent name collision
170+ // with the `__nonempty` field.
171+ Ident :: new ( & format ! ( "__field_{}" , variant_ident_str) , ident. span ( ) )
172+ }
173+
174+ fn generate_variants_union (
175+ generics : & Generics ,
176+ data : & DataEnum ,
177+ zerocopy_crate : & Path ,
178+ ) -> TokenStream {
166179 let ( _, ty_generics, _) = generics. split_for_impl ( ) ;
167180
168181 let fields = data. variants . iter ( ) . filter_map ( |variant| {
@@ -172,10 +185,7 @@ fn generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream
172185 return None ;
173186 }
174187
175- // Field names are prefixed with `__field_` to prevent name collision
176- // with the `__nonempty` field.
177- let field_name_str = crate :: ext:: to_ident_str ( & variant. ident ) ;
178- let field_name = Ident :: new ( & format ! ( "__field_{}" , field_name_str) , variant. ident . span ( ) ) ;
188+ let field_name = variants_union_field_ident ( & variant. ident ) ;
179189 let variant_struct_ident = variant_struct_ident ( & variant. ident ) ;
180190
181191 Some ( quote ! {
@@ -185,7 +195,7 @@ fn generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream
185195 } )
186196 } ) ;
187197
188- quote ! {
198+ let variants_union = parse_quote ! {
189199 #[ repr( C ) ]
190200 #[ allow( non_snake_case) ]
191201 union ___ZerocopyVariants #generics {
@@ -197,6 +207,14 @@ fn generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream
197207 // affect the layout.
198208 __nonempty: ( ) ,
199209 }
210+ } ;
211+
212+ let has_field =
213+ derive_has_field_struct_union ( & variants_union, & variants_union. data , zerocopy_crate) ;
214+
215+ quote ! {
216+ #variants_union
217+ #has_field
200218 }
201219}
202220
@@ -246,18 +264,20 @@ pub(crate) fn derive_is_bit_valid(
246264 } ;
247265
248266 let variant_structs = generate_variant_structs ( enum_ident, generics, data, zerocopy_crate) ;
249- let variants_union = generate_variants_union ( generics, data) ;
267+ let variants_union = generate_variants_union ( generics, data, zerocopy_crate ) ;
250268
251- let ( _, ty_generics, _) = generics. split_for_impl ( ) ;
269+ let ( _, ref ty_generics, _) = generics. split_for_impl ( ) ;
252270
253271 let has_fields = data. variants ( ) . into_iter ( ) . flat_map ( |( variant, fields) | {
254272 let variant_ident = & variant. unwrap ( ) . ident ;
273+ let variants_union_field_ident = variants_union_field_ident ( variant_ident) ;
255274 let field: Box < syn:: Type > = parse_quote ! ( ( ) ) ;
256- fields. into_iter ( ) . map ( move |( vis, ident, ty) | {
275+ fields. into_iter ( ) . enumerate ( ) . map ( move |( idx , ( vis, ident, ty) ) | {
257276 // Rust does not presently support explicit visibility modifiers on
258277 // enum fields, but we guard against the possibility to ensure this
259278 // derive remains sound.
260279 assert ! ( matches!( vis, syn:: Visibility :: Inherited ) ) ;
280+ let variant_struct_field_index = Index :: from ( idx + 1 ) ;
261281 ImplBlockBuilder :: new (
262282 ast,
263283 data,
@@ -274,6 +294,18 @@ pub(crate) fn derive_is_bit_valid(
274294 )
275295 . inner_extras ( quote ! {
276296 type Type = #ty;
297+
298+ #[ inline( always) ]
299+ fn project( slf: #zerocopy_crate:: PtrInner <' _, Self >) -> #zerocopy_crate:: PtrInner <' _, Self :: Type > {
300+ // SAFETY: By invariant on `___ZerocopyRawEnum`,
301+ // `___ZerocopyRawEnum` has the same layout as `Self`.
302+ let slf = unsafe { slf. cast:: <___ZerocopyRawEnum #ty_generics>( ) } ;
303+
304+ slf. project:: <_, 0 , { #zerocopy_crate:: ident_id!( variants) } >( )
305+ . project:: <_, 0 , { #zerocopy_crate:: ident_id!( #variants_union_field_ident) } >( )
306+ . project:: <_, 0 , { #zerocopy_crate:: ident_id!( value) } >( )
307+ . project:: <_, 0 , { #zerocopy_crate:: ident_id!( #variant_struct_field_index) } >( )
308+ }
277309 } )
278310 . build ( )
279311 } )
@@ -323,6 +355,22 @@ pub(crate) fn derive_is_bit_valid(
323355 }
324356 } ) ;
325357
358+ let raw_enum = parse_quote ! {
359+ #[ repr( C ) ]
360+ struct ___ZerocopyRawEnum #generics {
361+ tag: ___ZerocopyOuterTag,
362+ variants: ___ZerocopyVariants #ty_generics,
363+ }
364+ } ;
365+
366+ let raw_enum_projections =
367+ derive_has_field_struct_union ( & raw_enum, & raw_enum. data , zerocopy_crate) ;
368+
369+ let raw_enum = quote ! {
370+ #raw_enum
371+ #raw_enum_projections
372+ } ;
373+
326374 Ok ( quote ! {
327375 // SAFETY: We use `is_bit_valid` to validate that the bit pattern of the
328376 // enum's tag corresponds to one of the enum's discriminants. Then, we
@@ -351,11 +399,7 @@ pub(crate) fn derive_is_bit_valid(
351399
352400 #variants_union
353401
354- #[ repr( C ) ]
355- struct ___ZerocopyRawEnum #generics {
356- tag: ___ZerocopyOuterTag,
357- variants: ___ZerocopyVariants #ty_generics,
358- }
402+ #raw_enum
359403
360404 #( #has_fields) *
361405
@@ -399,26 +443,23 @@ pub(crate) fn derive_is_bit_valid(
399443 // invariant from `p`, so we re-assert that all of the bytes are
400444 // initialized.
401445 let raw_enum = unsafe { raw_enum. assume_initialized( ) } ;
446+
402447 // SAFETY:
403- // - This projection returns a subfield of `this ` using
404- // `addr_of_mut! `.
405- // - Because the subfield pointer is derived from `this `, it has the
406- // same provenance.
448+ // - This projection returns a subfield of `raw_enum ` using
449+ // `project `.
450+ // - Because the subfield pointer is derived from `raw_enum `, it has
451+ // the same provenance.
407452 // - The locations of `UnsafeCell`s in the subfield match the
408- // locations of `UnsafeCell`s in `this `. This is because the
453+ // locations of `UnsafeCell`s in `raw_enum `. This is because the
409454 // subfield pointer just points to a smaller portion of the
410455 // overall struct.
456+ let project = #zerocopy_crate:: pointer:: PtrInner :: project:: <
457+ _,
458+ 0 ,
459+ { #zerocopy_crate:: ident_id!( variants) }
460+ >;
411461 let variants = unsafe {
412- use #zerocopy_crate:: pointer:: PtrInner ;
413- raw_enum. cast_unsized_unchecked( |p: PtrInner <' _, ___ZerocopyRawEnum #ty_generics>| {
414- let p = p. as_non_null( ) . as_ptr( ) ;
415- let ptr = core_reexport:: ptr:: addr_of_mut!( ( * p) . variants) ;
416- // SAFETY: `ptr` is a projection into `p`, which is
417- // `NonNull`, and guaranteed not to wrap around the address
418- // space. Thus, `ptr` cannot be null.
419- let ptr = unsafe { core_reexport:: ptr:: NonNull :: new_unchecked( ptr) } ;
420- unsafe { PtrInner :: new( ptr) }
421- } )
462+ raw_enum. cast_unsized_unchecked( project)
422463 } ;
423464
424465 #[ allow( non_upper_case_globals) ]
0 commit comments