Skip to content

Commit 136f1cd

Browse files
committed
Compute closure captures
1 parent ea22d24 commit 136f1cd

40 files changed

+2323
-428
lines changed

crates/hir-def/src/body/lower.rs

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ use crate::{
3232
builtin_type::{BuiltinFloat, BuiltinInt, BuiltinUint},
3333
db::DefDatabase,
3434
expr::{
35-
dummy_expr_id, Array, Binding, BindingAnnotation, BindingId, ClosureKind, Expr, ExprId,
36-
FloatTypeWrapper, Label, LabelId, Literal, MatchArm, Movability, Pat, PatId,
35+
dummy_expr_id, Array, Binding, BindingAnnotation, BindingId, CaptureBy, ClosureKind, Expr,
36+
ExprId, FloatTypeWrapper, Label, LabelId, Literal, MatchArm, Movability, Pat, PatId,
3737
RecordFieldPat, RecordLitField, Statement,
3838
},
3939
item_scope::BuiltinShadowMode,
@@ -105,6 +105,7 @@ pub(super) fn lower(
105105
current_try_block: None,
106106
is_lowering_assignee_expr: false,
107107
is_lowering_generator: false,
108+
current_binding_owner: None,
108109
}
109110
.collect(params, body, is_async_fn)
110111
}
@@ -119,6 +120,7 @@ struct ExprCollector<'a> {
119120
current_try_block: Option<LabelId>,
120121
is_lowering_assignee_expr: bool,
121122
is_lowering_generator: bool,
123+
current_binding_owner: Option<ExprId>,
122124
}
123125

124126
#[derive(Debug, Default)]
@@ -218,7 +220,12 @@ impl ExprCollector<'_> {
218220
}
219221

220222
fn alloc_binding(&mut self, name: Name, mode: BindingAnnotation) -> BindingId {
221-
self.body.bindings.alloc(Binding { name, mode, definitions: SmallVec::new() })
223+
self.body.bindings.alloc(Binding {
224+
name,
225+
mode,
226+
definitions: SmallVec::new(),
227+
owner: self.current_binding_owner,
228+
})
222229
}
223230
fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId {
224231
let src = self.expander.to_source(ptr);
@@ -497,6 +504,8 @@ impl ExprCollector<'_> {
497504
}
498505
}
499506
ast::Expr::ClosureExpr(e) => {
507+
let (result_expr_id, prev_binding_owner) =
508+
self.initialize_binding_owner(syntax_ptr);
500509
let mut args = Vec::new();
501510
let mut arg_types = Vec::new();
502511
if let Some(pl) = e.param_list() {
@@ -530,18 +539,19 @@ impl ExprCollector<'_> {
530539
} else {
531540
ClosureKind::Closure
532541
};
542+
let capture_by =
543+
if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
533544
self.is_lowering_generator = prev_is_lowering_generator;
534-
535-
self.alloc_expr(
536-
Expr::Closure {
537-
args: args.into(),
538-
arg_types: arg_types.into(),
539-
ret_type,
540-
body,
541-
closure_kind,
542-
},
543-
syntax_ptr,
544-
)
545+
self.current_binding_owner = prev_binding_owner;
546+
self.body.exprs[result_expr_id] = Expr::Closure {
547+
args: args.into(),
548+
arg_types: arg_types.into(),
549+
ret_type,
550+
body,
551+
closure_kind,
552+
capture_by,
553+
};
554+
result_expr_id
545555
}
546556
ast::Expr::BinExpr(e) => {
547557
let op = e.op_kind();
@@ -581,7 +591,17 @@ impl ExprCollector<'_> {
581591
}
582592
ArrayExprKind::Repeat { initializer, repeat } => {
583593
let initializer = self.collect_expr_opt(initializer);
584-
let repeat = self.collect_expr_opt(repeat);
594+
let repeat = if let Some(repeat) = repeat {
595+
let (id, prev_owner) =
596+
self.initialize_binding_owner(AstPtr::new(&repeat));
597+
let tmp = self.collect_expr(repeat);
598+
self.body.exprs[id] =
599+
mem::replace(&mut self.body.exprs[tmp], Expr::Missing);
600+
self.current_binding_owner = prev_owner;
601+
id
602+
} else {
603+
self.missing_expr()
604+
};
585605
self.alloc_expr(
586606
Expr::Array(Array::Repeat { initializer, repeat }),
587607
syntax_ptr,
@@ -627,6 +647,16 @@ impl ExprCollector<'_> {
627647
})
628648
}
629649

650+
fn initialize_binding_owner(
651+
&mut self,
652+
syntax_ptr: AstPtr<ast::Expr>,
653+
) -> (ExprId, Option<ExprId>) {
654+
let result_expr_id = self.alloc_expr(Expr::Missing, syntax_ptr);
655+
let prev_binding_owner = self.current_binding_owner.take();
656+
self.current_binding_owner = Some(result_expr_id);
657+
(result_expr_id, prev_binding_owner)
658+
}
659+
630660
/// Desugar `try { <stmts>; <expr> }` into `'<new_label>: { <stmts>; ::std::ops::Try::from_output(<expr>) }`,
631661
/// `try { <stmts>; }` into `'<new_label>: { <stmts>; ::std::ops::Try::from_output(()) }`
632662
/// and save the `<new_label>` to use it as a break target for desugaring of the `?` operator.

crates/hir-def/src/body/pretty.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ use std::fmt::{self, Write};
55
use syntax::ast::HasName;
66

77
use crate::{
8-
expr::{Array, BindingAnnotation, BindingId, ClosureKind, Literal, Movability, Statement},
8+
expr::{
9+
Array, BindingAnnotation, BindingId, CaptureBy, ClosureKind, Literal, Movability, Statement,
10+
},
911
pretty::{print_generic_args, print_path, print_type_ref},
1012
type_ref::TypeRef,
1113
};
@@ -355,7 +357,7 @@ impl<'a> Printer<'a> {
355357
self.print_expr(*index);
356358
w!(self, "]");
357359
}
358-
Expr::Closure { args, arg_types, ret_type, body, closure_kind } => {
360+
Expr::Closure { args, arg_types, ret_type, body, closure_kind, capture_by } => {
359361
match closure_kind {
360362
ClosureKind::Generator(Movability::Static) => {
361363
w!(self, "static ");
@@ -365,6 +367,12 @@ impl<'a> Printer<'a> {
365367
}
366368
_ => (),
367369
}
370+
match capture_by {
371+
CaptureBy::Value => {
372+
w!(self, "move ");
373+
}
374+
CaptureBy::Ref => (),
375+
}
368376
w!(self, "|");
369377
for (i, (pat, ty)) in args.iter().zip(arg_types.iter()).enumerate() {
370378
if i != 0 {

crates/hir-def/src/expr.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ pub enum Expr {
233233
ret_type: Option<Interned<TypeRef>>,
234234
body: ExprId,
235235
closure_kind: ClosureKind,
236+
capture_by: CaptureBy,
236237
},
237238
Tuple {
238239
exprs: Box<[ExprId]>,
@@ -250,6 +251,14 @@ pub enum ClosureKind {
250251
Async,
251252
}
252253

254+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
255+
pub enum CaptureBy {
256+
/// `move |x| y + x`.
257+
Value,
258+
/// `move` keyword was not specified.
259+
Ref,
260+
}
261+
253262
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
254263
pub enum Movability {
255264
Static,
@@ -442,6 +451,22 @@ pub struct Binding {
442451
pub name: Name,
443452
pub mode: BindingAnnotation,
444453
pub definitions: SmallVec<[PatId; 1]>,
454+
/// Id of the closure/generator that owns this binding. If it is owned by the
455+
/// top level expression, this field would be `None`.
456+
pub owner: Option<ExprId>,
457+
}
458+
459+
impl Binding {
460+
pub fn is_upvar(&self, relative_to: ExprId) -> bool {
461+
match self.owner {
462+
Some(x) => {
463+
// We assign expression ids in a way that outer closures will recieve
464+
// a lower id
465+
x.into_raw() < relative_to.into_raw()
466+
}
467+
None => true,
468+
}
469+
}
445470
}
446471

447472
#[derive(Debug, Clone, Eq, PartialEq)]

crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ struct Foo;
1616
#[derive(Copy)]
1717
struct Foo;
1818
19-
impl < > core::marker::Copy for Foo< > {}"#]],
19+
impl < > core::marker::Copy for Foo< > where {}"#]],
2020
);
2121
}
2222

@@ -41,7 +41,7 @@ macro Copy {}
4141
#[derive(Copy)]
4242
struct Foo;
4343
44-
impl < > crate ::marker::Copy for Foo< > {}"#]],
44+
impl < > crate ::marker::Copy for Foo< > where {}"#]],
4545
);
4646
}
4747

@@ -57,7 +57,7 @@ struct Foo<A, B>;
5757
#[derive(Copy)]
5858
struct Foo<A, B>;
5959
60-
impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
60+
impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
6161
);
6262
}
6363

@@ -74,7 +74,7 @@ struct Foo<A, B, 'a, 'b>;
7474
#[derive(Copy)]
7575
struct Foo<A, B, 'a, 'b>;
7676
77-
impl <T0: core::marker::Copy, T1: core::marker::Copy, > core::marker::Copy for Foo<T0, T1, > {}"#]],
77+
impl <A: core::marker::Copy, B: core::marker::Copy, > core::marker::Copy for Foo<A, B, > where {}"#]],
7878
);
7979
}
8080

@@ -90,7 +90,7 @@ struct Foo<A, B>;
9090
#[derive(Clone)]
9191
struct Foo<A, B>;
9292
93-
impl <T0: core::clone::Clone, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
93+
impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Foo<A, B, > where {}"#]],
9494
);
9595
}
9696

@@ -106,6 +106,6 @@ struct Foo<const X: usize, T>(u32);
106106
#[derive(Clone)]
107107
struct Foo<const X: usize, T>(u32);
108108
109-
impl <const T0: usize, T1: core::clone::Clone, > core::clone::Clone for Foo<T0, T1, > {}"#]],
109+
impl <const X: usize, T: core::clone::Clone, > core::clone::Clone for Foo<X, T, > where u32: core::clone::Clone, {}"#]],
110110
);
111111
}

crates/hir-expand/src/builtin_derive_macro.rs

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use tracing::debug;
55

66
use crate::tt::{self, TokenId};
77
use syntax::{
8-
ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName},
8+
ast::{self, AstNode, HasGenericParams, HasModuleItem, HasName, HasTypeBounds},
99
match_ast,
1010
};
1111

@@ -60,8 +60,11 @@ pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander>
6060

6161
struct BasicAdtInfo {
6262
name: tt::Ident,
63-
/// `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
64-
param_types: Vec<Option<tt::Subtree>>,
63+
/// first field is the name, and
64+
/// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
65+
/// third fields is where bounds, if any
66+
param_types: Vec<(tt::Subtree, Option<tt::Subtree>, Option<tt::Subtree>)>,
67+
field_types: Vec<tt::Subtree>,
6568
}
6669

6770
fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
@@ -75,17 +78,34 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
7578
ExpandError::Other("no item found".into())
7679
})?;
7780
let node = item.syntax();
78-
let (name, params) = match_ast! {
81+
let (name, params, fields) = match_ast! {
7982
match node {
80-
ast::Struct(it) => (it.name(), it.generic_param_list()),
81-
ast::Enum(it) => (it.name(), it.generic_param_list()),
82-
ast::Union(it) => (it.name(), it.generic_param_list()),
83+
ast::Struct(it) => {
84+
(it.name(), it.generic_param_list(), it.field_list())
85+
},
86+
ast::Enum(it) => (it.name(), it.generic_param_list(), None),
87+
ast::Union(it) => (it.name(), it.generic_param_list(), it.record_field_list().map(|x| ast::FieldList::RecordFieldList(x))),
8388
_ => {
8489
debug!("unexpected node is {:?}", node);
8590
return Err(ExpandError::Other("expected struct, enum or union".into()))
8691
},
8792
}
8893
};
94+
let field_types = match fields {
95+
Some(fields) => match fields {
96+
ast::FieldList::RecordFieldList(x) => x
97+
.fields()
98+
.filter_map(|x| x.ty())
99+
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
100+
.collect(),
101+
ast::FieldList::TupleFieldList(x) => x
102+
.fields()
103+
.filter_map(|x| x.ty())
104+
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
105+
.collect(),
106+
},
107+
None => vec![],
108+
};
89109
let name = name.ok_or_else(|| {
90110
debug!("parsed item has no name");
91111
ExpandError::Other("missing name".into())
@@ -97,35 +117,46 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
97117
.into_iter()
98118
.flat_map(|param_list| param_list.type_or_const_params())
99119
.map(|param| {
100-
if let ast::TypeOrConstParam::Const(param) = param {
120+
let name = param
121+
.name()
122+
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
123+
.unwrap_or_else(tt::Subtree::empty);
124+
let bounds = match &param {
125+
ast::TypeOrConstParam::Type(x) => {
126+
x.type_bound_list().map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
127+
}
128+
ast::TypeOrConstParam::Const(_) => None,
129+
};
130+
let ty = if let ast::TypeOrConstParam::Const(param) = param {
101131
let ty = param
102132
.ty()
103133
.map(|ty| mbe::syntax_node_to_token_tree(ty.syntax()).0)
104134
.unwrap_or_else(tt::Subtree::empty);
105135
Some(ty)
106136
} else {
107137
None
108-
}
138+
};
139+
(name, ty, bounds)
109140
})
110141
.collect();
111-
Ok(BasicAdtInfo { name: name_token, param_types })
142+
Ok(BasicAdtInfo { name: name_token, param_types, field_types })
112143
}
113144

114145
fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResult<tt::Subtree> {
115146
let info = match parse_adt(tt) {
116147
Ok(info) => info,
117148
Err(e) => return ExpandResult::with_err(tt::Subtree::empty(), e),
118149
};
150+
let mut where_block = vec![];
119151
let (params, args): (Vec<_>, Vec<_>) = info
120152
.param_types
121153
.into_iter()
122-
.enumerate()
123-
.map(|(idx, param_ty)| {
124-
let ident = tt::Leaf::Ident(tt::Ident {
125-
span: tt::TokenId::unspecified(),
126-
text: format!("T{idx}").into(),
127-
});
154+
.map(|(ident, param_ty, bound)| {
128155
let ident_ = ident.clone();
156+
if let Some(b) = bound {
157+
let ident = ident.clone();
158+
where_block.push(quote! { #ident : #b , });
159+
}
129160
if let Some(ty) = param_ty {
130161
(quote! { const #ident : #ty , }, quote! { #ident_ , })
131162
} else {
@@ -134,9 +165,16 @@ fn expand_simple_derive(tt: &tt::Subtree, trait_path: tt::Subtree) -> ExpandResu
134165
}
135166
})
136167
.unzip();
168+
169+
where_block.extend(info.field_types.iter().map(|x| {
170+
let x = x.clone();
171+
let bound = trait_path.clone();
172+
quote! { #x : #bound , }
173+
}));
174+
137175
let name = info.name;
138176
let expanded = quote! {
139-
impl < ##params > #trait_path for #name < ##args > {}
177+
impl < ##params > #trait_path for #name < ##args > where ##where_block {}
140178
};
141179
ExpandResult::ok(expanded)
142180
}

0 commit comments

Comments
 (0)