Skip to content

Commit 0e9fa88

Browse files
committed
Make VariantDiscriminant derive macro support variants with non-'static lifetimes
1 parent c490bd9 commit 0e9fa88

File tree

8 files changed

+305
-15
lines changed

8 files changed

+305
-15
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ Please make sure to add your changes to the appropriate categories:
2020

2121
### Added
2222

23-
- n/a
23+
- Made `VariantDiscriminant` derive macro support variants with non-`'static` lifetimes.
2424

2525
### Changed
2626

macros/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ proc-macro = true
2020
[dependencies]
2121
proc-macro2 = { version = "1.0.81", features = ["span-locations"] }
2222
quote = "1.0.36"
23-
syn = { version = "2.0.60", features = ["full", "visit"] }
23+
syn = { version = "2.0.60", features = ["full", "visit", "visit-mut"] }

macros/src/enum_deriver.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use proc_macro2::TokenStream as TokenStream2;
22
use quote::quote;
33
use syn::{
4-
parse_quote, parse_quote_spanned, spanned::Spanned, visit::Visit as _, Fields, Type, Variant,
4+
parse_quote, parse_quote_spanned, spanned::Spanned, visit::Visit as _, visit_mut::VisitMut,
5+
Fields, Type, Variant,
56
};
67

7-
use crate::*;
8+
use crate::{type_visitor_mut::TypeVisitorMut, *};
89

910
pub(crate) struct EnumDeriver {
1011
item: syn::ItemEnum,
@@ -585,15 +586,23 @@ impl EnumDeriver {
585586
match nested {
586587
NestedDiscriminantType::Default => {
587588
let (field, _) = field_selection.expect("no selected field found");
588-
let field_type = &field.ty;
589589

590-
if self.uses_generic_const_or_type(field_type) {
590+
let mut visitor = TypeVisitor::new(&self.item.generics);
591+
visitor.visit_type(&field.ty);
592+
593+
if visitor.type_uses_const_or_type_param() {
591594
return Err(syn::Error::new(
592595
field.span(),
593596
"generic fields require an explicit nested discriminant type",
594597
));
595598
}
596599

600+
let field_type = if visitor.type_uses_lifetime_param() {
601+
self.type_replacing_lifetimes_with_static(&field.ty)
602+
} else {
603+
field.ty.clone()
604+
};
605+
597606
let nested_type = parse_quote! {
598607
<#field_type as ::enumcapsulate::VariantDiscriminant>::Discriminant
599608
};
@@ -741,11 +750,18 @@ impl EnumDeriver {
741750
}
742751
}
743752

753+
fn type_replacing_lifetimes_with_static(&self, ty: &syn::Type) -> syn::Type {
754+
let mut ty = ty.clone();
755+
let mut visitor = TypeVisitorMut::default().replace_lifetimes_with_static();
756+
visitor.visit_type_mut(&mut ty);
757+
ty
758+
}
759+
744760
fn uses_generic_const_or_type(&self, ty: &syn::Type) -> bool {
745761
let mut visitor = TypeVisitor::new(&self.item.generics);
746762

747763
visitor.visit_type(ty);
748764

749-
visitor.uses_const_or_type_param()
765+
visitor.type_uses_const_or_type_param()
750766
}
751767
}

macros/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::utils::tokenstream;
66
mod config;
77
mod enum_deriver;
88
mod type_visitor;
9+
mod type_visitor_mut;
910
mod utils;
1011

1112
use self::{config::*, enum_deriver::*, type_visitor::*, utils::*};

macros/src/type_visitor.rs

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,48 @@ use std::collections::HashSet;
33
use syn::visit::Visit;
44

55
pub struct TypeVisitor<'ast> {
6+
lifetime_param_idents: HashSet<&'ast syn::Ident>,
67
const_param_idents: HashSet<&'ast syn::Ident>,
78
type_param_idents: HashSet<&'ast syn::Ident>,
89

9-
uses_const_param: bool,
10-
uses_type_param: bool,
10+
type_uses_lifetime_param: bool,
11+
type_uses_const_param: bool,
12+
type_uses_type_param: bool,
1113
}
1214

1315
impl<'ast> TypeVisitor<'ast> {
1416
pub fn new(generics: &'ast syn::Generics) -> Self {
1517
Self {
18+
lifetime_param_idents: generics
19+
.lifetimes()
20+
.map(|param| &param.lifetime.ident)
21+
.collect(),
1622
const_param_idents: generics.const_params().map(|param| &param.ident).collect(),
1723
type_param_idents: generics.type_params().map(|param| &param.ident).collect(),
18-
uses_const_param: false,
19-
uses_type_param: false,
24+
type_uses_lifetime_param: false,
25+
type_uses_const_param: false,
26+
type_uses_type_param: false,
2027
}
2128
}
2229

23-
pub fn uses_const_or_type_param(self) -> bool {
24-
self.uses_const_param || self.uses_type_param
30+
#[allow(dead_code)]
31+
pub fn type_uses_lifetime_param(&self) -> bool {
32+
self.type_uses_lifetime_param
33+
}
34+
35+
#[allow(dead_code)]
36+
pub fn type_uses_const_param(&self) -> bool {
37+
self.type_uses_const_param
38+
}
39+
40+
#[allow(dead_code)]
41+
pub fn type_uses_type_param(&self) -> bool {
42+
self.type_uses_type_param
43+
}
44+
45+
#[allow(dead_code)]
46+
pub fn type_uses_const_or_type_param(&self) -> bool {
47+
self.type_uses_const_param || self.type_uses_type_param
2548
}
2649
}
2750

@@ -39,12 +62,18 @@ impl<'ast> Visit<'ast> for TypeVisitor<'ast> {
3962
let ident = &path_segment.ident;
4063

4164
if self.type_param_idents.contains(ident) {
42-
self.uses_type_param = true;
65+
self.type_uses_type_param = true;
4366
} else if self.const_param_idents.contains(ident) {
44-
self.uses_const_param = true;
67+
self.type_uses_const_param = true;
4568
}
4669
}
4770
}
4871
syn::visit::visit_type_path(self, node);
4972
}
73+
74+
fn visit_lifetime(&mut self, lifetime: &'ast syn::Lifetime) {
75+
if self.lifetime_param_idents.contains(&lifetime.ident) {
76+
self.type_uses_lifetime_param = true;
77+
}
78+
}
5079
}

macros/src/type_visitor_mut.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
use syn::visit_mut::VisitMut;
2+
3+
pub struct TypeVisitorMut {
4+
replace_lifetimes_with_static: bool,
5+
}
6+
7+
impl TypeVisitorMut {
8+
pub fn new() -> Self {
9+
Self {
10+
replace_lifetimes_with_static: false,
11+
}
12+
}
13+
14+
pub fn replace_lifetimes_with_static(mut self) -> Self {
15+
self.replace_lifetimes_with_static = true;
16+
self
17+
}
18+
}
19+
20+
impl Default for TypeVisitorMut {
21+
fn default() -> Self {
22+
Self::new()
23+
}
24+
}
25+
26+
impl VisitMut for TypeVisitorMut {
27+
fn visit_lifetime_mut(&mut self, lifetime: &mut syn::Lifetime) {
28+
lifetime.ident = quote::format_ident!("static");
29+
}
30+
}
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
use enumcapsulate::VariantDiscriminant;
2+
pub enum VariantWithLifetime<'a> {
3+
Variant(&'a ()),
4+
}
5+
pub enum VariantWithLifetimeDiscriminant {
6+
Variant,
7+
}
8+
#[automatically_derived]
9+
impl ::core::marker::Copy for VariantWithLifetimeDiscriminant {}
10+
#[automatically_derived]
11+
impl ::core::clone::Clone for VariantWithLifetimeDiscriminant {
12+
#[inline]
13+
fn clone(&self) -> VariantWithLifetimeDiscriminant {
14+
*self
15+
}
16+
}
17+
#[automatically_derived]
18+
impl ::core::cmp::Ord for VariantWithLifetimeDiscriminant {
19+
#[inline]
20+
fn cmp(&self, other: &VariantWithLifetimeDiscriminant) -> ::core::cmp::Ordering {
21+
::core::cmp::Ordering::Equal
22+
}
23+
}
24+
#[automatically_derived]
25+
impl ::core::cmp::PartialOrd for VariantWithLifetimeDiscriminant {
26+
#[inline]
27+
fn partial_cmp(
28+
&self,
29+
other: &VariantWithLifetimeDiscriminant,
30+
) -> ::core::option::Option<::core::cmp::Ordering> {
31+
::core::option::Option::Some(::core::cmp::Ordering::Equal)
32+
}
33+
}
34+
#[automatically_derived]
35+
impl ::core::cmp::Eq for VariantWithLifetimeDiscriminant {
36+
#[inline]
37+
#[doc(hidden)]
38+
#[coverage(off)]
39+
fn assert_receiver_is_total_eq(&self) -> () {}
40+
}
41+
#[automatically_derived]
42+
impl ::core::marker::StructuralPartialEq for VariantWithLifetimeDiscriminant {}
43+
#[automatically_derived]
44+
impl ::core::cmp::PartialEq for VariantWithLifetimeDiscriminant {
45+
#[inline]
46+
fn eq(&self, other: &VariantWithLifetimeDiscriminant) -> bool {
47+
true
48+
}
49+
}
50+
#[automatically_derived]
51+
impl ::core::hash::Hash for VariantWithLifetimeDiscriminant {
52+
#[inline]
53+
fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () {}
54+
}
55+
#[automatically_derived]
56+
impl ::core::fmt::Debug for VariantWithLifetimeDiscriminant {
57+
#[inline]
58+
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
59+
::core::fmt::Formatter::write_str(f, "Variant")
60+
}
61+
}
62+
impl<'a> ::enumcapsulate::VariantDiscriminant for VariantWithLifetime<'a> {
63+
type Discriminant = VariantWithLifetimeDiscriminant;
64+
fn variant_discriminant(&self) -> Self::Discriminant {
65+
match self {
66+
VariantWithLifetime::Variant(..) => VariantWithLifetimeDiscriminant::Variant,
67+
_ => ::core::panicking::panic("internal error: entered unreachable code"),
68+
}
69+
}
70+
}
71+
pub enum EnumWithLifetime<'a> {
72+
#[enumcapsulate(discriminant(nested))]
73+
VariantA(VariantWithLifetime<'a>),
74+
}
75+
pub enum EnumWithLifetimeDiscriminant {
76+
VariantA(
77+
<VariantWithLifetime<
78+
'static,
79+
> as ::enumcapsulate::VariantDiscriminant>::Discriminant,
80+
),
81+
}
82+
#[automatically_derived]
83+
impl ::core::marker::Copy for EnumWithLifetimeDiscriminant {}
84+
#[automatically_derived]
85+
impl ::core::clone::Clone for EnumWithLifetimeDiscriminant {
86+
#[inline]
87+
fn clone(&self) -> EnumWithLifetimeDiscriminant {
88+
let _: ::core::clone::AssertParamIsClone<
89+
<VariantWithLifetime<
90+
'static,
91+
> as ::enumcapsulate::VariantDiscriminant>::Discriminant,
92+
>;
93+
*self
94+
}
95+
}
96+
#[automatically_derived]
97+
impl ::core::cmp::Ord for EnumWithLifetimeDiscriminant {
98+
#[inline]
99+
fn cmp(&self, other: &EnumWithLifetimeDiscriminant) -> ::core::cmp::Ordering {
100+
match (self, other) {
101+
(
102+
EnumWithLifetimeDiscriminant::VariantA(__self_0),
103+
EnumWithLifetimeDiscriminant::VariantA(__arg1_0),
104+
) => ::core::cmp::Ord::cmp(__self_0, __arg1_0),
105+
}
106+
}
107+
}
108+
#[automatically_derived]
109+
impl ::core::cmp::PartialOrd for EnumWithLifetimeDiscriminant {
110+
#[inline]
111+
fn partial_cmp(
112+
&self,
113+
other: &EnumWithLifetimeDiscriminant,
114+
) -> ::core::option::Option<::core::cmp::Ordering> {
115+
match (self, other) {
116+
(
117+
EnumWithLifetimeDiscriminant::VariantA(__self_0),
118+
EnumWithLifetimeDiscriminant::VariantA(__arg1_0),
119+
) => ::core::cmp::PartialOrd::partial_cmp(__self_0, __arg1_0),
120+
}
121+
}
122+
}
123+
#[automatically_derived]
124+
impl ::core::cmp::Eq for EnumWithLifetimeDiscriminant {
125+
#[inline]
126+
#[doc(hidden)]
127+
#[coverage(off)]
128+
fn assert_receiver_is_total_eq(&self) -> () {
129+
let _: ::core::cmp::AssertParamIsEq<
130+
<VariantWithLifetime<
131+
'static,
132+
> as ::enumcapsulate::VariantDiscriminant>::Discriminant,
133+
>;
134+
}
135+
}
136+
#[automatically_derived]
137+
impl ::core::marker::StructuralPartialEq for EnumWithLifetimeDiscriminant {}
138+
#[automatically_derived]
139+
impl ::core::cmp::PartialEq for EnumWithLifetimeDiscriminant {
140+
#[inline]
141+
fn eq(&self, other: &EnumWithLifetimeDiscriminant) -> bool {
142+
match (self, other) {
143+
(
144+
EnumWithLifetimeDiscriminant::VariantA(__self_0),
145+
EnumWithLifetimeDiscriminant::VariantA(__arg1_0),
146+
) => *__self_0 == *__arg1_0,
147+
}
148+
}
149+
}
150+
#[automatically_derived]
151+
impl ::core::hash::Hash for EnumWithLifetimeDiscriminant {
152+
#[inline]
153+
fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () {
154+
match self {
155+
EnumWithLifetimeDiscriminant::VariantA(__self_0) => {
156+
::core::hash::Hash::hash(__self_0, state)
157+
}
158+
}
159+
}
160+
}
161+
#[automatically_derived]
162+
impl ::core::fmt::Debug for EnumWithLifetimeDiscriminant {
163+
#[inline]
164+
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
165+
match self {
166+
EnumWithLifetimeDiscriminant::VariantA(__self_0) => {
167+
::core::fmt::Formatter::debug_tuple_field1_finish(
168+
f,
169+
"VariantA",
170+
&__self_0,
171+
)
172+
}
173+
}
174+
}
175+
}
176+
impl<'a> ::enumcapsulate::VariantDiscriminant for EnumWithLifetime<'a> {
177+
type Discriminant = EnumWithLifetimeDiscriminant;
178+
fn variant_discriminant(&self) -> Self::Discriminant {
179+
match self {
180+
EnumWithLifetime::VariantA(inner, ..) => {
181+
EnumWithLifetimeDiscriminant::VariantA(inner.variant_discriminant())
182+
}
183+
_ => ::core::panicking::panic("internal error: entered unreachable code"),
184+
}
185+
}
186+
}
187+
fn main() {}

0 commit comments

Comments
 (0)