@@ -8,6 +8,8 @@ use crate::deriving::generic::ty::*;
8
8
use crate :: deriving:: generic:: * ;
9
9
use crate :: deriving:: { path_local, path_std} ;
10
10
11
+ /// Expands a `#[derive(PartialEq)]` attribute into an implementation for the
12
+ /// target item.
11
13
pub ( crate ) fn expand_deriving_partial_eq (
12
14
cx : & ExtCtxt < ' _ > ,
13
15
span : Span ,
@@ -16,62 +18,6 @@ pub(crate) fn expand_deriving_partial_eq(
16
18
push : & mut dyn FnMut ( Annotatable ) ,
17
19
is_const : bool ,
18
20
) {
19
- fn cs_eq ( cx : & ExtCtxt < ' _ > , span : Span , substr : & Substructure < ' _ > ) -> BlockOrExpr {
20
- let base = true ;
21
- let expr = cs_fold (
22
- true , // use foldl
23
- cx,
24
- span,
25
- substr,
26
- |cx, fold| match fold {
27
- CsFold :: Single ( field) => {
28
- let [ other_expr] = & field. other_selflike_exprs [ ..] else {
29
- cx. dcx ( )
30
- . span_bug ( field. span , "not exactly 2 arguments in `derive(PartialEq)`" ) ;
31
- } ;
32
-
33
- // We received arguments of type `&T`. Convert them to type `T` by stripping
34
- // any leading `&`. This isn't necessary for type checking, but
35
- // it results in better error messages if something goes wrong.
36
- //
37
- // Note: for arguments that look like `&{ x }`, which occur with packed
38
- // structs, this would cause expressions like `{ self.x } == { other.x }`,
39
- // which isn't valid Rust syntax. This wouldn't break compilation because these
40
- // AST nodes are constructed within the compiler. But it would mean that code
41
- // printed by `-Zunpretty=expanded` (or `cargo expand`) would have invalid
42
- // syntax, which would be suboptimal. So we wrap these in parens, giving
43
- // `({ self.x }) == ({ other.x })`, which is valid syntax.
44
- let convert = |expr : & P < Expr > | {
45
- if let ExprKind :: AddrOf ( BorrowKind :: Ref , Mutability :: Not , inner) =
46
- & expr. kind
47
- {
48
- if let ExprKind :: Block ( ..) = & inner. kind {
49
- // `&{ x }` form: remove the `&`, add parens.
50
- cx. expr_paren ( field. span , inner. clone ( ) )
51
- } else {
52
- // `&x` form: remove the `&`.
53
- inner. clone ( )
54
- }
55
- } else {
56
- expr. clone ( )
57
- }
58
- } ;
59
- cx. expr_binary (
60
- field. span ,
61
- BinOpKind :: Eq ,
62
- convert ( & field. self_expr ) ,
63
- convert ( other_expr) ,
64
- )
65
- }
66
- CsFold :: Combine ( span, expr1, expr2) => {
67
- cx. expr_binary ( span, BinOpKind :: And , expr1, expr2)
68
- }
69
- CsFold :: Fieldless => cx. expr_bool ( span, base) ,
70
- } ,
71
- ) ;
72
- BlockOrExpr :: new_expr ( expr)
73
- }
74
-
75
21
let structural_trait_def = TraitDef {
76
22
span,
77
23
path : path_std ! ( marker:: StructuralPartialEq ) ,
@@ -97,7 +43,9 @@ pub(crate) fn expand_deriving_partial_eq(
97
43
ret_ty: Path ( path_local!( bool ) ) ,
98
44
attributes: thin_vec![ cx. attr_word( sym:: inline, span) ] ,
99
45
fieldless_variants_strategy: FieldlessVariantsStrategy :: Unify ,
100
- combine_substructure: combine_substructure( Box :: new( |a, b, c| cs_eq( a, b, c) ) ) ,
46
+ combine_substructure: combine_substructure( Box :: new( |a, b, c| {
47
+ BlockOrExpr :: new_expr( get_substructure_equality_expr( a, b, c) )
48
+ } ) ) ,
101
49
} ] ;
102
50
103
51
let trait_def = TraitDef {
@@ -113,3 +61,156 @@ pub(crate) fn expand_deriving_partial_eq(
113
61
} ;
114
62
trait_def. expand ( cx, mitem, item, push)
115
63
}
64
+
65
+ /// Generates the equality expression for a struct or enum variant when deriving
66
+ /// `PartialEq`.
67
+ ///
68
+ /// This function generates an expression that checks if all fields of a struct
69
+ /// or enum variant are equal.
70
+ /// - Scalar fields are compared first for efficiency, followed by compound
71
+ /// fields.
72
+ /// - If there are no fields, returns `true` (fieldless types are always equal).
73
+ ///
74
+ /// Whether a field is considered "scalar" is determined by comparing the symbol
75
+ /// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc).
76
+ /// This check is based on the type's symbol.
77
+ ///
78
+ /// ### Example 1
79
+ /// ```
80
+ /// #[derive(PartialEq)]
81
+ /// struct i32;
82
+ ///
83
+ /// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not
84
+ /// // the primitive), it will not be treated as scalar. The function will still
85
+ /// // check equality of `field_2` first because the symbol matches `i32`.
86
+ /// #[derive(PartialEq)]
87
+ /// struct Struct {
88
+ /// field_1: &'static str,
89
+ /// field_2: i32,
90
+ /// }
91
+ /// ```
92
+ ///
93
+ /// ### Example 2
94
+ /// ```
95
+ /// mod ty {
96
+ /// pub type i32 = i32;
97
+ /// }
98
+ ///
99
+ /// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
100
+ /// // However, the function will not reorder the fields because the symbol for
101
+ /// // `ty::i32` does not match the symbol for the primitive `i32`
102
+ /// // ("ty::i32" != "i32").
103
+ /// #[derive(PartialEq)]
104
+ /// struct Struct {
105
+ /// field_1: &'static str,
106
+ /// field_2: ty::i32,
107
+ /// }
108
+ /// ```
109
+ ///
110
+ /// For enums, the discriminant is compared first, then the rest of the fields.
111
+ ///
112
+ /// # Panics
113
+ ///
114
+ /// If called on static or all-fieldless enums/structs, which should not occur
115
+ /// during derive expansion.
116
+ fn get_substructure_equality_expr (
117
+ cx : & ExtCtxt < ' _ > ,
118
+ span : Span ,
119
+ substructure : & Substructure < ' _ > ,
120
+ ) -> P < Expr > {
121
+ use SubstructureFields :: * ;
122
+
123
+ match substructure. fields {
124
+ EnumMatching ( .., fields) | Struct ( .., fields) => {
125
+ let combine = move |acc, field| {
126
+ let rhs = get_field_equality_expr ( cx, field) ;
127
+ if let Some ( lhs) = acc {
128
+ // Combine the previous comparison with the current field
129
+ // using logical AND.
130
+ return Some ( cx. expr_binary ( field. span , BinOpKind :: And , lhs, rhs) ) ;
131
+ }
132
+ // Start the chain with the first field's comparison.
133
+ Some ( rhs)
134
+ } ;
135
+
136
+ // First compare scalar fields, then compound fields, combining all
137
+ // with logical AND.
138
+ return fields
139
+ . iter ( )
140
+ . filter ( |field| !field. maybe_scalar )
141
+ . fold ( fields. iter ( ) . filter ( |field| field. maybe_scalar ) . fold ( None , combine) , combine)
142
+ // If there are no fields, treat as always equal.
143
+ . unwrap_or_else ( || cx. expr_bool ( span, true ) ) ;
144
+ }
145
+ EnumDiscr ( disc, match_expr) => {
146
+ let lhs = get_field_equality_expr ( cx, disc) ;
147
+ let Some ( match_expr) = match_expr else {
148
+ return lhs;
149
+ } ;
150
+ // Compare the discriminant first (cheaper), then the rest of the
151
+ // fields.
152
+ return cx. expr_binary ( disc. span , BinOpKind :: And , lhs, match_expr. clone ( ) ) ;
153
+ }
154
+ StaticEnum ( ..) => cx. dcx ( ) . span_bug (
155
+ span,
156
+ "unexpected static enum encountered during `derive(PartialEq)` expansion" ,
157
+ ) ,
158
+ StaticStruct ( ..) => cx. dcx ( ) . span_bug (
159
+ span,
160
+ "unexpected static struct encountered during `derive(PartialEq)` expansion" ,
161
+ ) ,
162
+ AllFieldlessEnum ( ..) => cx. dcx ( ) . span_bug (
163
+ span,
164
+ "unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion" ,
165
+ ) ,
166
+ }
167
+ }
168
+
169
+ /// Generates an equality comparison expression for a single struct or enum
170
+ /// field.
171
+ ///
172
+ /// This function produces an AST expression that compares the `self` and
173
+ /// `other` values for a field using `==`. It removes any leading references
174
+ /// from both sides for readability. If the field is a block expression, it is
175
+ /// wrapped in parentheses to ensure valid syntax.
176
+ ///
177
+ /// # Panics
178
+ ///
179
+ /// Panics if there are not exactly two arguments to compare (should be `self`
180
+ /// and `other`).
181
+ fn get_field_equality_expr ( cx : & ExtCtxt < ' _ > , field : & FieldInfo ) -> P < Expr > {
182
+ let [ rhs] = & field. other_selflike_exprs [ ..] else {
183
+ cx. dcx ( ) . span_bug ( field. span , "not exactly 2 arguments in `derive(PartialEq)`" ) ;
184
+ } ;
185
+
186
+ cx. expr_binary (
187
+ field. span ,
188
+ BinOpKind :: Eq ,
189
+ wrap_block_expr ( cx, peel_refs ( & field. self_expr ) ) ,
190
+ wrap_block_expr ( cx, peel_refs ( rhs) ) ,
191
+ )
192
+ }
193
+
194
+ /// Removes all leading immutable references from an expression.
195
+ ///
196
+ /// This is used to strip away any number of leading `&` from an expression
197
+ /// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable
198
+ /// references are preserved.
199
+ fn peel_refs ( mut expr : & P < Expr > ) -> P < Expr > {
200
+ while let ExprKind :: AddrOf ( BorrowKind :: Ref , Mutability :: Not , inner) = & expr. kind {
201
+ expr = & inner;
202
+ }
203
+ expr. clone ( )
204
+ }
205
+
206
+ /// Wraps a block expression in parentheses to ensure valid AST in macro
207
+ /// expansion output.
208
+ ///
209
+ /// If the given expression is a block, it is wrapped in parentheses; otherwise,
210
+ /// it is returned unchanged.
211
+ fn wrap_block_expr ( cx : & ExtCtxt < ' _ > , expr : P < Expr > ) -> P < Expr > {
212
+ if matches ! ( & expr. kind, ExprKind :: Block ( ..) ) {
213
+ return cx. expr_paren ( expr. span , expr) ;
214
+ }
215
+ expr
216
+ }
0 commit comments