From 5a1545e7add4c0826ef39ce2b91e505ad4dc9b84 Mon Sep 17 00:00:00 2001 From: Travis Hance Date: Fri, 11 Oct 2024 21:32:09 -0400 Subject: [PATCH] returns (#1283) --- dependencies/syn/src/gen/clone.rs | 10 + dependencies/syn/src/gen/debug.rs | 10 + dependencies/syn/src/gen/eq.rs | 12 +- dependencies/syn/src/gen/fold.rs | 13 + dependencies/syn/src/gen/hash.rs | 10 + dependencies/syn/src/gen/visit.rs | 13 + dependencies/syn/src/gen/visit_mut.rs | 13 + dependencies/syn/src/item.rs | 5 + dependencies/syn/src/lib.rs | 2 +- dependencies/syn/src/token.rs | 2 + dependencies/syn/src/verus.rs | 38 ++ dependencies/syn/syn.json | 20 + dependencies/syn/tests/debug/gen.rs | 24 + source/builtin/src/lib.rs | 8 + source/builtin_macros/src/syntax.rs | 18 +- source/rust_verify/src/fn_call_to_vir.rs | 17 +- source/rust_verify/src/rust_to_vir_func.rs | 14 +- source/rust_verify/src/verus_items.rs | 2 + .../tests/external_fn_specification.rs | 31 ++ source/rust_verify_test/tests/fndef_types.rs | 269 +++++++++++ .../tests/returns_postcondition.rs | 445 ++++++++++++++++++ source/vir/src/ast.rs | 4 + source/vir/src/ast_simplify.rs | 19 + source/vir/src/ast_visitor.rs | 12 + source/vir/src/headers.rs | 11 + source/vir/src/modes.rs | 6 + source/vir/src/traits.rs | 24 +- source/vir/src/well_formed.rs | 41 +- 28 files changed, 1078 insertions(+), 15 deletions(-) create mode 100644 source/rust_verify_test/tests/returns_postcondition.rs diff --git a/dependencies/syn/src/gen/clone.rs b/dependencies/syn/src/gen/clone.rs index bdb4cfe800..63b026a650 100644 --- a/dependencies/syn/src/gen/clone.rs +++ b/dependencies/syn/src/gen/clone.rs @@ -2244,6 +2244,15 @@ impl Clone for ReturnType { } } #[cfg_attr(doc_cfg, doc(cfg(feature = "clone-impls")))] +impl Clone for Returns { + fn clone(&self) -> Self { + Returns { + token: self.token.clone(), + exprs: self.exprs.clone(), + } + } +} +#[cfg_attr(doc_cfg, doc(cfg(feature = "clone-impls")))] impl Clone for RevealHide { fn clone(&self) -> Self { RevealHide { @@ -2280,6 +2289,7 @@ impl Clone for Signature { requires: self.requires.clone(), recommends: self.recommends.clone(), ensures: self.ensures.clone(), + returns: self.returns.clone(), decreases: self.decreases.clone(), invariants: self.invariants.clone(), unwind: self.unwind.clone(), diff --git a/dependencies/syn/src/gen/debug.rs b/dependencies/syn/src/gen/debug.rs index b679409a7b..b693fd5981 100644 --- a/dependencies/syn/src/gen/debug.rs +++ b/dependencies/syn/src/gen/debug.rs @@ -3066,6 +3066,15 @@ impl Debug for ReturnType { } } #[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] +impl Debug for Returns { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let mut formatter = formatter.debug_struct("Returns"); + formatter.field("token", &self.token); + formatter.field("exprs", &self.exprs); + formatter.finish() + } +} +#[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] impl Debug for RevealHide { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { let mut formatter = formatter.debug_struct("RevealHide"); @@ -3102,6 +3111,7 @@ impl Debug for Signature { formatter.field("requires", &self.requires); formatter.field("recommends", &self.recommends); formatter.field("ensures", &self.ensures); + formatter.field("returns", &self.returns); formatter.field("decreases", &self.decreases); formatter.field("invariants", &self.invariants); formatter.field("unwind", &self.unwind); diff --git a/dependencies/syn/src/gen/eq.rs b/dependencies/syn/src/gen/eq.rs index 40f5322a45..41ee2ee1bf 100644 --- a/dependencies/syn/src/gen/eq.rs +++ b/dependencies/syn/src/gen/eq.rs @@ -2152,6 +2152,14 @@ impl PartialEq for ReturnType { } } #[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] +impl Eq for Returns {} +#[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] +impl PartialEq for Returns { + fn eq(&self, other: &Self) -> bool { + self.exprs == other.exprs + } +} +#[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] impl Eq for RevealHide {} #[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] impl PartialEq for RevealHide { @@ -2177,8 +2185,8 @@ impl PartialEq for Signature { && self.variadic == other.variadic && self.output == other.output && self.prover == other.prover && self.requires == other.requires && self.recommends == other.recommends && self.ensures == other.ensures - && self.decreases == other.decreases && self.invariants == other.invariants - && self.unwind == other.unwind + && self.returns == other.returns && self.decreases == other.decreases + && self.invariants == other.invariants && self.unwind == other.unwind } } #[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] diff --git a/dependencies/syn/src/gen/fold.rs b/dependencies/syn/src/gen/fold.rs index b317fa0a0d..3a1b402d1e 100644 --- a/dependencies/syn/src/gen/fold.rs +++ b/dependencies/syn/src/gen/fold.rs @@ -750,6 +750,9 @@ pub trait Fold { fn fold_return_type(&mut self, i: ReturnType) -> ReturnType { fold_return_type(self, i) } + fn fold_returns(&mut self, i: Returns) -> Returns { + fold_returns(self, i) + } fn fold_reveal_hide(&mut self, i: RevealHide) -> RevealHide { fold_reveal_hide(self, i) } @@ -3566,6 +3569,15 @@ where } } } +pub fn fold_returns(f: &mut F, node: Returns) -> Returns +where + F: Fold + ?Sized, +{ + Returns { + token: Token![returns](tokens_helper(f, &node.token.span)), + exprs: f.fold_specification(node.exprs), + } +} pub fn fold_reveal_hide(f: &mut F, node: RevealHide) -> RevealHide where F: Fold + ?Sized, @@ -3616,6 +3628,7 @@ where requires: (node.requires).map(|it| f.fold_requires(it)), recommends: (node.recommends).map(|it| f.fold_recommends(it)), ensures: (node.ensures).map(|it| f.fold_ensures(it)), + returns: (node.returns).map(|it| f.fold_returns(it)), decreases: (node.decreases).map(|it| f.fold_signature_decreases(it)), invariants: (node.invariants).map(|it| f.fold_signature_invariants(it)), unwind: (node.unwind).map(|it| f.fold_signature_unwind(it)), diff --git a/dependencies/syn/src/gen/hash.rs b/dependencies/syn/src/gen/hash.rs index 5c5d618a9b..3ed415106b 100644 --- a/dependencies/syn/src/gen/hash.rs +++ b/dependencies/syn/src/gen/hash.rs @@ -2873,6 +2873,15 @@ impl Hash for ReturnType { } } #[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] +impl Hash for Returns { + fn hash(&self, state: &mut H) + where + H: Hasher, + { + self.exprs.hash(state); + } +} +#[cfg_attr(doc_cfg, doc(cfg(feature = "extra-traits")))] impl Hash for RevealHide { fn hash(&self, state: &mut H) where @@ -2909,6 +2918,7 @@ impl Hash for Signature { self.requires.hash(state); self.recommends.hash(state); self.ensures.hash(state); + self.returns.hash(state); self.decreases.hash(state); self.invariants.hash(state); self.unwind.hash(state); diff --git a/dependencies/syn/src/gen/visit.rs b/dependencies/syn/src/gen/visit.rs index e71c6a140a..cad7f19d21 100644 --- a/dependencies/syn/src/gen/visit.rs +++ b/dependencies/syn/src/gen/visit.rs @@ -734,6 +734,9 @@ pub trait Visit<'ast> { fn visit_return_type(&mut self, i: &'ast ReturnType) { visit_return_type(self, i); } + fn visit_returns(&mut self, i: &'ast Returns) { + visit_returns(self, i); + } fn visit_reveal_hide(&mut self, i: &'ast RevealHide) { visit_reveal_hide(self, i); } @@ -3958,6 +3961,13 @@ where } } } +pub fn visit_returns<'ast, V>(v: &mut V, node: &'ast Returns) +where + V: Visit<'ast> + ?Sized, +{ + tokens_helper(v, &node.token.span); + v.visit_specification(&node.exprs); +} pub fn visit_reveal_hide<'ast, V>(v: &mut V, node: &'ast RevealHide) where V: Visit<'ast> + ?Sized, @@ -4032,6 +4042,9 @@ where if let Some(it) = &node.ensures { v.visit_ensures(it); } + if let Some(it) = &node.returns { + v.visit_returns(it); + } if let Some(it) = &node.decreases { v.visit_signature_decreases(it); } diff --git a/dependencies/syn/src/gen/visit_mut.rs b/dependencies/syn/src/gen/visit_mut.rs index 4a886a05cf..95ef1e9d80 100644 --- a/dependencies/syn/src/gen/visit_mut.rs +++ b/dependencies/syn/src/gen/visit_mut.rs @@ -735,6 +735,9 @@ pub trait VisitMut { fn visit_return_type_mut(&mut self, i: &mut ReturnType) { visit_return_type_mut(self, i); } + fn visit_returns_mut(&mut self, i: &mut Returns) { + visit_returns_mut(self, i); + } fn visit_reveal_hide_mut(&mut self, i: &mut RevealHide) { visit_reveal_hide_mut(self, i); } @@ -3952,6 +3955,13 @@ where } } } +pub fn visit_returns_mut(v: &mut V, node: &mut Returns) +where + V: VisitMut + ?Sized, +{ + tokens_helper(v, &mut node.token.span); + v.visit_specification_mut(&mut node.exprs); +} pub fn visit_reveal_hide_mut(v: &mut V, node: &mut RevealHide) where V: VisitMut + ?Sized, @@ -4026,6 +4036,9 @@ where if let Some(it) = &mut node.ensures { v.visit_ensures_mut(it); } + if let Some(it) = &mut node.returns { + v.visit_returns_mut(it); + } if let Some(it) = &mut node.decreases { v.visit_signature_decreases_mut(it); } diff --git a/dependencies/syn/src/item.rs b/dependencies/syn/src/item.rs index 835540bca1..2bb95b56f5 100644 --- a/dependencies/syn/src/item.rs +++ b/dependencies/syn/src/item.rs @@ -947,6 +947,7 @@ ast_struct! { pub requires: Option, pub recommends: Option, pub ensures: Option, + pub returns: Option, pub decreases: Option, pub invariants: Option, pub unwind: Option, @@ -976,6 +977,7 @@ impl Signature { self.requires = None; self.recommends = None; self.ensures = None; + self.returns = None; self.decreases = None; self.invariants = None; self.unwind = None; @@ -1725,6 +1727,7 @@ pub mod parsing { let requires: Option = input.parse()?; let recommends: Option = input.parse()?; let ensures: Option = input.parse()?; + let returns: Option = input.parse()?; let decreases: Option = input.parse()?; let invariants: Option = input.parse()?; let unwind: Option = input.parse()?; @@ -1748,6 +1751,7 @@ pub mod parsing { requires, recommends, ensures, + returns, decreases, invariants, unwind, @@ -3591,6 +3595,7 @@ mod printing { self.requires.to_tokens(tokens); self.recommends.to_tokens(tokens); self.ensures.to_tokens(tokens); + self.returns.to_tokens(tokens); self.decreases.to_tokens(tokens); } } diff --git a/dependencies/syn/src/lib.rs b/dependencies/syn/src/lib.rs index 0cbc678cd5..26d7c6c931 100644 --- a/dependencies/syn/src/lib.rs +++ b/dependencies/syn/src/lib.rs @@ -462,7 +462,7 @@ pub use crate::verus::{ GlobalSizeOf, Invariant, InvariantEnsures, InvariantExceptBreak, InvariantNameSet, InvariantNameSetAny, InvariantNameSetList, InvariantNameSetNone, ItemBroadcastGroup, MatchesOpExpr, MatchesOpToken, Mode, ModeExec, ModeGhost, ModeProof, ModeSpec, ModeSpecChecked, - ModeTracked, Open, OpenRestricted, Publish, Recommends, Requires, RevealHide, + ModeTracked, Open, OpenRestricted, Publish, Recommends, Requires, Returns, RevealHide, SignatureDecreases, SignatureInvariants, SignatureUnwind, Specification, TypeFnSpec, View, }; diff --git a/dependencies/syn/src/token.rs b/dependencies/syn/src/token.rs index de84dd602f..e5386dd7d1 100644 --- a/dependencies/syn/src/token.rs +++ b/dependencies/syn/src/token.rs @@ -714,6 +714,7 @@ define_keywords! { "requires" pub struct Requires /// `requires` "recommends" pub struct Recommends /// `recommends` "ensures" pub struct Ensures /// `ensures` + "returns" pub struct Returns /// `returns` "decreases" pub struct Decreases /// `decreases` "opens_invariants" pub struct OpensInvariants /// `opens_invariants` "invariant_except_break" pub struct InvariantExceptBreak /// `invariant_except_break` @@ -938,6 +939,7 @@ macro_rules! export_token_macro { [requires] => { $crate::token::Requires }; [recommends] => { $crate::token::Recommends }; [ensures] => { $crate::token::Ensures }; + [returns] => { $crate::token::Returns }; [decreases] => { $crate::token::Decreases }; [opens_invariants] => { $crate::token::OpensInvariants }; [invariant_except_break] => { $crate::token::InvariantExceptBreak }; diff --git a/dependencies/syn/src/verus.rs b/dependencies/syn/src/verus.rs index b9943f8730..6ae0168561 100644 --- a/dependencies/syn/src/verus.rs +++ b/dependencies/syn/src/verus.rs @@ -126,6 +126,13 @@ ast_struct! { } } +ast_struct! { + pub struct Returns { + pub token: Token![returns], + pub exprs: Specification, + } +} + ast_struct! { pub struct InvariantExceptBreak { pub token: Token![invariant_except_break], @@ -505,6 +512,7 @@ pub mod parsing { || input.peek(Token![invariant]) || input.peek(Token![invariant_ensures]) || input.peek(Token![ensures]) + || input.peek(Token![returns]) || input.peek(Token![decreases]) || input.peek(Token![via]) || input.peek(Token![when]) @@ -586,6 +594,17 @@ pub mod parsing { } } + #[cfg_attr(doc_cfg, doc(cfg(feature = "parsing")))] + impl Parse for Returns { + fn parse(input: ParseStream) -> Result { + let token = input.parse()?; + Ok(Returns { + token, + exprs: input.parse()?, + }) + } + } + #[cfg_attr(doc_cfg, doc(cfg(feature = "parsing")))] impl Parse for InvariantExceptBreak { fn parse(input: ParseStream) -> Result { @@ -772,6 +791,17 @@ pub mod parsing { } } + #[cfg_attr(doc_cfg, doc(cfg(feature = "parsing")))] + impl Parse for Option { + fn parse(input: ParseStream) -> Result { + if input.peek(Token![returns]) { + input.parse().map(Some) + } else { + Ok(None) + } + } + } + #[cfg_attr(doc_cfg, doc(cfg(feature = "parsing")))] impl Parse for Option { fn parse(input: ParseStream) -> Result { @@ -1262,6 +1292,14 @@ mod printing { } } + #[cfg_attr(doc_cfg, doc(cfg(feature = "printing")))] + impl ToTokens for Returns { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.token.to_tokens(tokens); + self.exprs.to_tokens(tokens); + } + } + #[cfg_attr(doc_cfg, doc(cfg(feature = "printing")))] impl ToTokens for InvariantExceptBreak { fn to_tokens(&self, tokens: &mut TokenStream) { diff --git a/dependencies/syn/syn.json b/dependencies/syn/syn.json index d8a56d10d6..9cc5937850 100644 --- a/dependencies/syn/syn.json +++ b/dependencies/syn/syn.json @@ -5611,6 +5611,20 @@ ] } }, + { + "ident": "Returns", + "features": { + "any": [] + }, + "fields": { + "token": { + "token": "Returns" + }, + "exprs": { + "syn": "Specification" + } + } + }, { "ident": "RevealHide", "features": { @@ -5758,6 +5772,11 @@ "syn": "Ensures" } }, + "returns": { + "option": { + "syn": "Returns" + } + }, "decreases": { "option": { "syn": "SignatureDecreases" @@ -7116,6 +7135,7 @@ "RemEq": "%=", "Requires": "requires", "Return": "return", + "Returns": "returns", "Reveal": "reveal", "RevealWithFuel": "reveal_with_fuel", "SelfType": "Self", diff --git a/dependencies/syn/tests/debug/gen.rs b/dependencies/syn/tests/debug/gen.rs index 0965d59bd1..c080580b82 100644 --- a/dependencies/syn/tests/debug/gen.rs +++ b/dependencies/syn/tests/debug/gen.rs @@ -6118,6 +6118,14 @@ impl Debug for Lite { } } } +impl Debug for Lite { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + let _val = &self.value; + let mut formatter = formatter.debug_struct("Returns"); + formatter.field("exprs", Lite(&_val.exprs)); + formatter.finish() + } +} impl Debug for Lite { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { let _val = &self.value; @@ -6337,6 +6345,22 @@ impl Debug for Lite { } formatter.field("ensures", Print::ref_cast(val)); } + if let Some(val) = &_val.returns { + #[derive(RefCast)] + #[repr(transparent)] + struct Print(syn::Returns); + impl Debug for Print { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("Some")?; + let _val = &self.0; + formatter.write_str("(")?; + Debug::fmt(Lite(_val), formatter)?; + formatter.write_str(")")?; + Ok(()) + } + } + formatter.field("returns", Print::ref_cast(val)); + } if let Some(val) = &_val.decreases { #[derive(RefCast)] #[repr(transparent)] diff --git a/source/builtin/src/lib.rs b/source/builtin/src/lib.rs index 416d1d07f1..f126a3f0bd 100644 --- a/source/builtin/src/lib.rs +++ b/source/builtin/src/lib.rs @@ -45,6 +45,14 @@ pub fn ensures(_a: A) { unimplemented!(); } +// Can only appear at beginning of function body +#[cfg(verus_keep_ghost)] +#[rustc_diagnostic_item = "verus::builtin::returns"] +#[verifier::proof] +pub fn returns(_a: A) { + unimplemented!(); +} + // Can only appear at beginning of spec function body #[cfg(verus_keep_ghost)] #[rustc_diagnostic_item = "verus::builtin::recommends"] diff --git a/source/builtin_macros/src/syntax.rs b/source/builtin_macros/src/syntax.rs index b07249b844..0590ef939a 100644 --- a/source/builtin_macros/src/syntax.rs +++ b/source/builtin_macros/src/syntax.rs @@ -29,8 +29,8 @@ use syn_verus::{ InvariantNameSetList, Item, ItemBroadcastGroup, ItemConst, ItemEnum, ItemFn, ItemImpl, ItemMod, ItemStatic, ItemStruct, ItemTrait, ItemUnion, Lit, Local, MatchesOpExpr, MatchesOpToken, ModeSpec, ModeSpecChecked, Pat, Path, PathArguments, PathSegment, Publish, Recommends, - Requires, ReturnType, Signature, SignatureDecreases, SignatureInvariants, SignatureUnwind, - Stmt, Token, TraitItem, TraitItemMethod, Type, TypeFnSpec, UnOp, Visibility, + Requires, ReturnType, Returns, Signature, SignatureDecreases, SignatureInvariants, + SignatureUnwind, Stmt, Token, TraitItem, TraitItemMethod, Type, TypeFnSpec, UnOp, Visibility, }; const VERUS_SPEC: &str = "VERUS_SPEC__"; @@ -415,6 +415,7 @@ impl Visitor { let requires = self.take_ghost(&mut sig.requires); let recommends = self.take_ghost(&mut sig.recommends); let ensures = self.take_ghost(&mut sig.ensures); + let returns = self.take_ghost(&mut sig.returns); let decreases = self.take_ghost(&mut sig.decreases); let opens_invariants = self.take_ghost(&mut sig.invariants); let unwind = self.take_ghost(&mut sig.unwind); @@ -517,6 +518,19 @@ impl Visitor { } } } + if let Some(Returns { token, mut exprs }) = returns { + if exprs.exprs.len() > 0 { + for expr in exprs.exprs.iter_mut() { + self.visit_expr_mut(expr); + } + stmts.push(Stmt::Semi( + Expr::Verbatim( + quote_spanned_builtin!(builtin, token.span => #builtin::returns([#exprs])), + ), + Semi { spans: [token.span] }, + )); + } + } if let Some(SignatureDecreases { decreases: Decreases { token, mut exprs }, when, via }) = decreases { diff --git a/source/rust_verify/src/fn_call_to_vir.rs b/source/rust_verify/src/fn_call_to_vir.rs index 57f1bfdf4a..70e1006067 100644 --- a/source/rust_verify/src/fn_call_to_vir.rs +++ b/source/rust_verify/src/fn_call_to_vir.rs @@ -65,6 +65,7 @@ pub(crate) fn fn_call_to_vir<'tcx>( SpecItem::Requires | SpecItem::Recommends | SpecItem::Ensures + | SpecItem::Returns | SpecItem::OpensInvariantsNone | SpecItem::OpensInvariantsAny | SpecItem::OpensInvariants @@ -302,7 +303,10 @@ fn verus_item_to_vir<'tcx, 'a>( record_spec_fn_no_proof_args(bctx, expr); mk_expr(ExprX::Header(Arc::new(HeaderExprX::NoMethodBody))) } - SpecItem::Requires | SpecItem::Recommends | SpecItem::OpensInvariants => { + SpecItem::Requires + | SpecItem::Recommends + | SpecItem::OpensInvariants + | SpecItem::Returns => { record_spec_fn_no_proof_args(bctx, expr); unsupported_err_unless!( args_len == 1, @@ -316,6 +320,13 @@ fn verus_item_to_vir<'tcx, 'a>( let vir_args = vec_map_result(&subargs, |arg| expr_to_vir(&bctx, arg, ExprModifier::REGULAR))?; + if matches!(spec_item, SpecItem::Returns) && subargs.len() != 1 { + return err_span( + expr.span, + "`returns` clause should have exactly 1 expression", + ); + } + for (arg, vir_arg) in subargs.iter().zip(vir_args.iter()) { let typ = vir::ast_util::undecorate_typ(&vir_arg.typ); match spec_item { @@ -337,6 +348,9 @@ fn verus_item_to_vir<'tcx, 'a>( ); } }, + SpecItem::Returns => { + // type is checked in well_formed.rs + } _ => unreachable!(), } } @@ -347,6 +361,7 @@ fn verus_item_to_vir<'tcx, 'a>( SpecItem::OpensInvariants => { Arc::new(HeaderExprX::InvariantOpens(Arc::new(vir_args))) } + SpecItem::Returns => Arc::new(HeaderExprX::Returns(vir_args[0].clone())), _ => unreachable!(), }; mk_expr(ExprX::Header(header)) diff --git a/source/rust_verify/src/rust_to_vir_func.rs b/source/rust_verify/src/rust_to_vir_func.rs index 3bb6c5a4f4..f54ad6ce4f 100644 --- a/source/rust_verify/src/rust_to_vir_func.rs +++ b/source/rust_verify/src/rust_to_vir_func.rs @@ -843,6 +843,9 @@ pub(crate) fn check_item_fn<'tcx>( if mode == Mode::Spec && (header.require.len() + header.ensure.len()) > 0 { return err_span(sig.span, "spec functions cannot have requires/ensures"); } + if mode == Mode::Spec && header.returns.is_some() { + return err_span(sig.span, "spec functions cannot have `returns` clause"); + } if mode != Mode::Spec && header.recommend.len() > 0 { return err_span(sig.span, "non-spec functions cannot have recommends"); } @@ -1071,6 +1074,7 @@ pub(crate) fn check_item_fn<'tcx>( ret, ens_has_return, require: if mode == Mode::Spec { Arc::new(recommend) } else { header.require }, + returns: header.returns, ensure: ensure, decrease: header.decrease, decrease_when: header.decrease_when, @@ -1130,6 +1134,7 @@ fn fix_external_fn_specification_trait_method_decl_typs( ens_has_return, require, ensure, + returns, decrease, decrease_when, decrease_by, @@ -1202,6 +1207,7 @@ fn fix_external_fn_specification_trait_method_decl_typs( unsupported_err_unless!(require.len() == 0, span, "requires clauses"); unsupported_err_unless!(ensure.len() == 0, span, "ensures clauses"); + unsupported_err_unless!(returns.is_some(), span, "returns clauses"); unsupported_err_unless!(decrease.len() == 0, span, "decreases clauses"); unsupported_err_unless!(decrease_when.is_none(), span, "decreases_when clauses"); unsupported_err_unless!(decrease_by.is_none(), span, "decreases_by clauses"); @@ -1225,6 +1231,7 @@ fn fix_external_fn_specification_trait_method_decl_typs( ens_has_return, require, ensure, + returns, decrease, decrease_when, decrease_by, @@ -1652,7 +1659,10 @@ pub(crate) fn check_item_const_or_static<'tcx>( return err_span(span, "consts cannot have requires/recommends"); } if ret_mode == Mode::Spec && header.ensure.len() > 0 { - return err_span(span, "spec functions cannot have ensures"); + return err_span(span, "spec consts cannot have ensures"); + } + if header.returns.is_some() { + return err_span(span, "consts cannot have `returns` clause"); } let ret_name = air_unique_var(RETURN_VALUE); @@ -1699,6 +1709,7 @@ pub(crate) fn check_item_const_or_static<'tcx>( ens_has_return, require: Arc::new(vec![]), ensure, + returns: None, decrease: Arc::new(vec![]), decrease_when: None, decrease_by: None, @@ -1808,6 +1819,7 @@ pub(crate) fn check_foreign_item_fn<'tcx>( ens_has_return, require: Arc::new(vec![]), ensure: Arc::new(vec![]), + returns: None, decrease: Arc::new(vec![]), decrease_when: None, decrease_by: None, diff --git a/source/rust_verify/src/verus_items.rs b/source/rust_verify/src/verus_items.rs index 985eb7dd54..248031f1bc 100644 --- a/source/rust_verify/src/verus_items.rs +++ b/source/rust_verify/src/verus_items.rs @@ -94,6 +94,7 @@ pub(crate) enum SpecItem { Requires, Recommends, Ensures, + Returns, InvariantExceptBreak, Invariant, Decreases, @@ -352,6 +353,7 @@ fn verus_items_map() -> Vec<(&'static str, VerusItem)> { ("verus::builtin::requires", VerusItem::Spec(SpecItem::Requires)), ("verus::builtin::recommends", VerusItem::Spec(SpecItem::Recommends)), ("verus::builtin::ensures", VerusItem::Spec(SpecItem::Ensures)), + ("verus::builtin::returns", VerusItem::Spec(SpecItem::Returns)), ("verus::builtin::invariant_except_break", VerusItem::Spec(SpecItem::InvariantExceptBreak)), ("verus::builtin::invariant", VerusItem::Spec(SpecItem::Invariant)), ("verus::builtin::decreases", VerusItem::Spec(SpecItem::Decreases)), diff --git a/source/rust_verify_test/tests/external_fn_specification.rs b/source/rust_verify_test/tests/external_fn_specification.rs index c17c620f10..081139ed3e 100644 --- a/source/rust_verify_test/tests/external_fn_specification.rs +++ b/source/rust_verify_test/tests/external_fn_specification.rs @@ -1311,3 +1311,34 @@ test_verify_one_file! { } } => Err(err) => assert_vir_error_msg(err, "an item in a trait impl cannot be marked 'external'") } + +test_verify_one_file! { + #[test] test_returns_clause verus_code! { + #[verifier(external)] + fn negate_bool(b: bool, x: u8) -> bool { + !b + } + + #[verifier(external_fn_specification)] + fn negate_bool_requires_ensures(b: bool, x: u8) -> bool + requires x != 0, + returns !b + { + negate_bool(b, x) + } + + fn test1() { + let ret_b = negate_bool(true, 1); + assert(ret_b == false); + } + + fn test2() { + let ret_b = negate_bool(true, 0); // FAILS + } + + fn test3() { + let ret_b = negate_bool(true, 1); + assert(ret_b == true); // FAILS + } + } => Err(err) => assert_fails(err, 2) +} diff --git a/source/rust_verify_test/tests/fndef_types.rs b/source/rust_verify_test/tests/fndef_types.rs index 8859996e7a..39fc565799 100644 --- a/source/rust_verify_test/tests/fndef_types.rs +++ b/source/rust_verify_test/tests/fndef_types.rs @@ -1411,3 +1411,272 @@ test_verify_one_file! { } } => Err(err) => assert_vir_error_msg(err, "using a datatype constructor as a function value") } + +test_verify_one_file_with_options! { + #[test] test_returns_clause ["vstd"] => verus_code! { + fn llama(x: u8) -> bool + requires x == 4 || x == 7, + returns (x == 4) + { + x == 4 + } + + fn test() { + let t = llama; + + let b = t(4); + assert(b); + + let b = t(7); + assert(!b); + + assert(forall |x| (x == 4 || x == 7) ==> call_requires(llama, (x,))); + assert(forall |x, y| call_ensures(llama, (x,), y) ==> x == 4 ==> y); + assert(forall |x, y| call_ensures(llama, (x,), y) ==> x == 7 ==> !y); + } + + fn test2() { + let t = llama; + + let b = t(4); + assert(!b); // FAILS + } + + fn test3() { + let t = llama; + + t(12); // FAILS + } + + fn test4() { + assert(forall |x| (x == 5) ==> call_requires(llama, (x,))); // FAILS + } + + fn test5() { + assert(forall |x, y| call_ensures(llama, (x,), y) ==> x == 4 ==> !y); // FAILS + } + } => Err(err) => assert_fails(err, 4) +} + +test_verify_one_file_with_options! { + #[test] test_returns_clause2 ["vstd"] => verus_code! { + fn llama(x: u8) -> (b: bool) + requires x == 4 || x == 7 || x == 9, + ensures x == 4 || x == 7 + returns (x == 4) + { + if x == 9 { + loop { } + } + x == 4 + } + + fn test() { + let t = llama; + + let b = t(4); + assert(b); + + let b = t(7); + assert(!b); + + assert(forall |x| (x == 4 || x == 7) ==> call_requires(llama, (x,))); + assert(forall |x, y| call_ensures(llama, (x,), y) ==> + (x == 4 || x == 7) && (y == (x == 4))); + } + + fn test2() { + let t = llama; + + let b = t(4); + assert(!b); // FAILS + } + + fn test3() { + let t = llama; + + t(12); // FAILS + } + + fn test4() { + assert(forall |x| (x == 5) ==> call_requires(llama, (x,))); // FAILS + } + + fn test5() { + let t = llama; + + let b = t(9); + assert(false); + } + + fn test6() { + let t = llama; + + let b = t(7); + assert(b); // FAILS + } + + fn test7() { + let t = llama; + + let b = t(7); + assert(!b); + } + } => Err(err) => assert_fails(err, 4) +} + +test_verify_one_file_with_options! { + #[test] call_ensures_return_clause_on_trait_method_decl ["vstd"] => verus_code! { + trait Tr : Sized { + spec fn ens(&self, i: u8, s: &Self) -> bool; + spec fn ret(&self, i: u8) -> Self; + + fn test(&self, i: u8) -> (s: Self) + ensures self.ens(i, &s), + returns self.ret(i); + } + + // ok + + struct X { j: u8 } + + impl Tr for X { + spec fn ens(&self, i: u8, s: &Self) -> bool { self.j + i < 250 } + spec fn ret(&self, i: u8) -> Self { + X { + j: (self.j + i) as u8 + } + } + + fn test(&self, i: u8) -> (s: Self) + ensures !(20 <= i < 30), + { + if self.j as u64 + i as u64 >= 250 || (20 <= i && i < 30) { + loop { } + } + X { j: self.j + i } + } + } + + // generic + + fn test1(t: T, i: u8, r: T) { + assert(call_ensures(T::test, (&t, i), r) ==> + t.ens(i, &r) && r == t.ret(i)); + } + + fn test1_fail(t: T, i: u8, r: T) { + assert(t.ens(i, &r) && r == t.ret(i) ==> + call_ensures(T::test, (&t, i), r)); // FAILS + } + + // specific + + fn test2(x: X, i: u8, r: X) { + assert(call_ensures(X::test, (&x, i), r) ==> + x.j + i < 250 && r == (X { j: (x.j + i) as u8 }) && !(20 <= i < 30)); + } + + fn test2_fail(x: X, i: u8, r: X) { + assert(x.j + i < 250 && r == (X { j: (x.j + i) as u8 }) && !(20 <= i < 30) + ==> call_ensures(X::test, (&x, i), r)); // FAILS + } + } => Err(err) => assert_fails(err, 2) +} + +test_verify_one_file_with_options! { + #[test] call_ensures_return_clause_on_trait_method_impl ["vstd"] => verus_code! { + trait Tr : Sized { + spec fn ens(&self, i: u8, s: &Self) -> bool; + + fn test(&self, i: u8) -> (s: Self) + ensures self.ens(i, &s); + } + + // ok + + struct X { j: u8 } + + impl Tr for X { + spec fn ens(&self, i: u8, s: &Self) -> bool { self.j + i < 250 } + + fn test(&self, i: u8) -> (s: Self) + ensures !(20 <= i < 30), + returns (X { j: (self.j + i) as u8 }), + { + if self.j as u64 + i as u64 >= 250 || (20 <= i && i < 30) { + loop { } + } + X { j: self.j + i } + } + } + + // specific + + fn test2(x: X, i: u8, r: X) { + assert(call_ensures(X::test, (&x, i), r) ==> + x.j + i < 250 && r == (X { j: (x.j + i) as u8 }) && !(20 <= i < 30)); + } + + fn test2_fail(x: X, i: u8, r: X) { + assert(x.j + i < 250 && r == (X { j: (x.j + i) as u8 }) && !(20 <= i < 30) + ==> call_ensures(X::test, (&x, i), r)); // FAILS + } + } => Err(err) => assert_fails(err, 1) +} + +test_verify_one_file! { + #[test] call_ensures_returns_on_default_method_impl verus_code! { + trait Tr : Sized { + fn test(&self, i: u8) -> (s: &Self) + ensures 20 <= i < 30, + returns + self + { + if !(20 <= i && i < 30) { loop { } } + self + } + } + + struct X { } + impl Tr for X { } + + struct Y { } + impl Tr for Y { + fn test(&self, i: u8) -> (s: &Self) + ensures 15 <= i < 25, + { + if !(20 <= i && i < 25) { loop { } } + self + } + } + + // Generic + + fn test_generic(t: T, i: u8, r: T) { + assert(call_ensures(T::test, (&t, i), &r) ==> 20 <= i < 30 && r == t); + } + + fn test_generic_fails(t: T, i: u8, r: T) { + assert(call_ensures(T::test, (&t, i), &r) <== 20 <= i < 30 && r == t); // FAILS + } + + // Specific + + fn test_x(t: X, i: u8, r: X) { + assert(call_ensures(X::test, (&t, i), &r) ==> 20 <= i < 30 && r == t); + } + + fn test_x_fails(t: X, i: u8, r: X) { + assert(call_ensures(X::test, (&t, i), &r) <== 20 <= i < 30 && r == t); // FAILS + } + + fn test_y(t: Y, i: u8, r: Y) { + assert(call_ensures(Y::test, (&t, i), &r) ==> 20 <= i < 35 && r == t); + } + + fn test_y_fails(t: Y, i: u8, r: Y) { + assert(call_ensures(Y::test, (&t, i), &r) <== 20 <= i < 35 && r == t); // FAILS + } + } => Err(err) => assert_fails(err, 3) +} diff --git a/source/rust_verify_test/tests/returns_postcondition.rs b/source/rust_verify_test/tests/returns_postcondition.rs new file mode 100644 index 0000000000..08f6ee4788 --- /dev/null +++ b/source/rust_verify_test/tests/returns_postcondition.rs @@ -0,0 +1,445 @@ +#![feature(rustc_private)] +#[macro_use] +mod common; +use common::*; + +test_verify_one_file! { + #[test] returns_basic verus_code! { + fn test() -> u8 + returns 20u8, + { + 20u8 + } + + proof fn proof_test() -> u8 + returns 20u8, + { + 20u8 + } + + + fn test2() { + let j = test(); + assert(j == 20); + } + + fn test3() -> u8 + returns 20u8, // FAILS + { + 19u8 + } + + fn test4() -> u8 + returns 20u8, + { + return 19u8; // FAILS + } + + fn test5(a: u8, b: u8) -> (k: u8) + requires a + b < 256, + ensures a + b < 257, + returns (a + b) as u8, + { + return a; // FAILS + } + + fn test6(a: u8, b: u8) -> (k: u8) + requires a + b < 256, + ensures a + b < 250, + returns (a + b) as u8, + { + return a + b; // FAILS + } + + proof fn proof_test5(a: u8, b: u8) -> (k: u8) + requires a + b < 256, + ensures a + b < 257, + returns (a + b) as u8, + { + return a; // FAILS + } + + proof fn proof_test6(a: u8, b: u8) -> (k: u8) + requires a + b < 256, + ensures a + b < 250, + returns (a + b) as u8, + { + return (a + b) as u8; // FAILS + } + + } => Err(err) => assert_fails(err, 6) +} + +test_verify_one_file! { + #[test] wrong_type verus_code! { + fn test() -> bool + returns 20u8, + { + true + } + } => Err(err) => assert_vir_error_msg(err, "type of `returns` clause does not match function return type") +} + +test_verify_one_file! { + #[test] spec_fn_returns verus_code! { + spec fn f(x: u8) -> bool + returns x == 3 + { + x == 3 + } + } => Err(err) => assert_vir_error_msg(err, "spec functions cannot have `returns` clause") +} + +test_verify_one_file! { + #[test] spec_fn_returns_references_ret_param verus_code! { + fn test() -> (x: u8) + returns x, + { + 20u8 + } + } => Err(err) => assert_rust_error_msg(err, "cannot find value `x` in this scope") +} + +test_verify_one_file! { + #[test] returning_unit_value verus_code! { + fn test(x: u8) + returns (), + { + } + + fn test(x: u8) + ensures false, // FAILS + returns (), + { + } + } => Err(err) => assert_fails(err, 1) +} + +test_verify_one_file! { + #[test] default_trait_fn_with_returns verus_code! { + trait Tr : Sized { + fn test(&self) -> &Self + returns + self + { + self + } + } + + struct X { } + + impl Tr for X { } + + fn test2(t: T) { + let t2 = t.test(); + assert(t2 == t); + } + + fn test3(t: T) { + let t2 = t.test(); + assert(t2 == t); + assert(false); // FAILS + } + } => Err(err) => assert_fails(err, 1) +} + +test_verify_one_file! { + #[test] default_trait_fn_with_returns_override verus_code! { + trait Tr : Sized { + fn test(&self) -> &Self + returns + self + { + self + } + } + + struct X { i: u8 } + + impl Tr for X { + fn test(&self) -> (some_ret_value: &Self) + ensures self.i == 5 + { + if self.i != 5 { loop { } } + self + } + } + + fn test2(x: X) { + let x2 = x.test(); + assert(x2 == x); + assert(x.i == 5); + } + + fn test3(x: X) { + let x2 = x.test(); + assert(x2 == x); + assert(x.i == 5); + assert(false); // FAILS + } + } => Err(err) => assert_fails(err, 1) +} + +test_verify_one_file! { + #[test] default_trait_fn_with_returns_conflict verus_code! { + trait Tr : Sized { + fn test(&self) -> &Self + returns + self + { + self + } + } + + struct X { i: u8 } + + impl Tr for X { + fn test(&self) -> &Self + returns &(X { i: 0 }) + { + &X { i: 0 } + } + } + } => Err(err) => assert_vir_error_msg(err, "a `returns` clause cannot be declared on both a trait method impl and its declaration") +} + +test_verify_one_file! { + #[test] default_trait_fn_with_returns2 verus_code! { + trait Tr : Sized { + fn some_other(&self) -> &Self; + + fn test(&self) -> &Self + returns + self + { + return some_other(self); // FAILS + } + } + } => Err(err) => assert_fails(err, 1) +} + +test_verify_one_file! { + #[test] trait_returns_conflict verus_code! { + trait Tr : Sized { + spec fn ens(&self, i: u8, s: &Self) -> bool; + spec fn ret(&self, i: u8) -> Self; + + fn test(&self, i: u8) -> (s: Self) + ensures self.ens(i, &s), + returns self.ret(i); + } + + struct U { j: u8 } + + impl Tr for U { + spec fn ens(&self, i: u8, s: &Self) -> bool { true } + spec fn ret(&self, i: u8) -> Self { + U { + j: i + } + } + + fn test(&self, i: u8) -> (s: Self) + returns (U { j: 0 }), + { + return U { j: 0 }; // FAILS + } + } + } => Err(err) => assert_vir_error_msg(err, "a `returns` clause cannot be declared on both a trait method impl and its declaration") +} + +test_verify_one_file! { + #[test] trait_returns_on_trait_method_decl verus_code! { + trait Tr : Sized { + spec fn ens(&self, i: u8, s: &Self) -> bool; + spec fn ret(&self, i: u8) -> Self; + + fn test(&self, i: u8) -> (s: Self) + ensures self.ens(i, &s), + returns self.ret(i); + } + + // ok + + struct X { j: u8 } + + impl Tr for X { + spec fn ens(&self, i: u8, s: &Self) -> bool { self.j + i < 250 } + spec fn ret(&self, i: u8) -> Self { + X { + j: (self.j + i) as u8 + } + } + + fn test(&self, i: u8) -> (s: Self) { + if self.j as u64 + i as u64 >= 250 { + loop { } + } + X { j: self.j + i } + } + } + + // fail inherited ensures + + struct Y { j: u8 } + + impl Tr for Y { + spec fn ens(&self, i: u8, s: &Self) -> bool { self.j + i < 256 } + spec fn ret(&self, i: u8) -> Self { + Y { + j: self.j + } + } + + fn test(&self, i: u8) -> (s: Self) { + return Y { j: self.j }; // FAILS + } + } + + // fail inherited returns + + struct Z { j: u8 } + + impl Tr for Z { + spec fn ens(&self, i: u8, s: &Self) -> bool { self.j + i < 256 } + spec fn ret(&self, i: u8) -> Self { + Z { + j: (self.j + i) as u8 + } + } + + fn test(&self, i: u8) -> (s: Self) { + if self.j as u64 + i as u64 >= 256 { + loop { } + } + return Z { j: self.j }; // FAILS + } + } + + // fail inherited returns with extra ensures + + struct W { j: u8 } + + impl Tr for W { + spec fn ens(&self, i: u8, s: &Self) -> bool { self.j + i < 256 } + spec fn ret(&self, i: u8) -> Self { + W { + j: (self.j + i) as u8 + } + } + + fn test(&self, i: u8) -> (s: Self) + ensures s.j == self.j, + { + if self.j as u64 + i as u64 >= 256 { + loop { } + } + return W { j: self.j }; // FAILS + } + } + + // Caller, generic + + fn test_call(t: T, i: u8) { + let c = t.test(i); + assert(t.ens(i, &c)); + assert(c == t.ret(i)); + } + + fn test_call_fail(t: T, i: u8) { + let c = t.test(i); + assert(t.ens(i, &c)); + assert(c == t.ret(i)); + assert(false); // FAILS + } + + // Caller, specific + + fn test_specific_call(x: X, i: u8) { + let c = x.test(i); + assert(x.j + i < 250); + assert(c == X { j: (x.j + i) as u8 }); + } + + fn test_specific_call_fail(x: X, i: u8) { + let c = x.test(i); + assert(x.j + i < 250); + assert(c == X { j: (x.j + i) as u8 }); + assert(false); // FAILS + } + } => Err(err) => assert_fails(err, 5) +} + +test_verify_one_file! { + #[test] trait_returns_on_trait_method_impl verus_code! { + trait Tr : Sized { + spec fn ens(&self, i: u8, s: &Self) -> bool; + + fn test(&self, i: u8) -> (s: Self) + ensures self.ens(i, &s); + } + + // ok + + struct X { j: u8 } + + impl Tr for X { + spec fn ens(&self, i: u8, s: &Self) -> bool { self.j + i < 256 } + + fn test(&self, i: u8) -> (s: Self) + returns (X { j: (self.j + i) as u8 }) + { + if self.j as u64 + i as u64 >= 256 { + loop { } + } + X { j: self.j + i } + } + } + + // fail inherited ensures + + struct Y { j: u8 } + + impl Tr for Y { + spec fn ens(&self, i: u8, s: &Self) -> bool { self.j + i < 256 } + + fn test(&self, i: u8) -> (s: Self) + returns (Y { j: self.j }) + { + return Y { j: self.j }; // FAILS + } + } + + // fail new returns + + struct Z { j: u8 } + + impl Tr for Z { + spec fn ens(&self, i: u8, s: &Self) -> bool { self.j + i < 250 } + + fn test(&self, i: u8) -> (s: Self) + returns (Z { j: (self.j + i) as u8 }) + { + if self.j as u64 + i as u64 >= 250 { + loop { } + } + return Z { j: self.j }; // FAILS + } + } + + // Caller + + fn test_call(z: Z, i: u8) { + let z2 = z.test(i); + assert(z.j + i < 250); + assert(z2.j == z.j + i); + } + + fn test_call_fail(z: Z, i: u8) { + let z2 = z.test(i); + assert(z.j + i < 250); + assert(z2.j == z.j + i); + assert(false); // FAILS + } + } => Err(err) => assert_fails(err, 3) +} diff --git a/source/vir/src/ast.rs b/source/vir/src/ast.rs index c22c5d4b4a..df9f2a005f 100644 --- a/source/vir/src/ast.rs +++ b/source/vir/src/ast.rs @@ -534,6 +534,8 @@ pub enum HeaderExprX { Requires(Exprs), /// Postconditions on exec/proof functions, with an optional name and type for the return value Ensures(Option<(VarIdent, Typ)>, Exprs), + /// Returns clause + Returns(Expr), /// Recommended preconditions on spec functions, used to help diagnose mistakes in specifications. /// Checking of recommends is disabled by default. Recommends(Exprs), @@ -1074,6 +1076,8 @@ pub struct FunctionX { pub require: Exprs, /// Postconditions (proof/exec functions only) pub ensure: Exprs, + /// Expression in the 'returns' clause + pub returns: Option, /// Decreases clause to ensure recursive function termination /// decrease.len() == 0 means no decreases clause /// decrease.len() >= 1 means list of expressions, interpreted in lexicographic order diff --git a/source/vir/src/ast_simplify.rs b/source/vir/src/ast_simplify.rs index 44c9cf609d..80f52e4288 100644 --- a/source/vir/src/ast_simplify.rs +++ b/source/vir/src/ast_simplify.rs @@ -935,6 +935,24 @@ fn simplify_function( ) -> Result { state.reset_for_function(); let mut functionx = function.x.clone(); + + if let Some(r) = functionx.returns.clone() { + functionx.returns = None; + + if functionx.ens_has_return { + let var = SpannedTyped::new( + &r.span, + &functionx.ret.x.typ, + ExprX::Var(functionx.ret.x.name.clone()), + ); + let eq = mk_eq(&r.span, &var, &r); + Arc::make_mut(&mut functionx.ensure).push(eq); + } else { + // For a unit return type, any returns clause is tautological so we + // can just skip appending to the postconditions. + } + } + let local = LocalCtxt { span: function.span.clone(), typ_params: (*functionx.typ_params).clone() }; @@ -1013,6 +1031,7 @@ fn simplify_function( ); functionx.ret = functionx.ret.new_x(crate::ast::ParamX { name: ret_name, ..functionx.ret.x.clone() }); + Ok(Spanned::new(function.span.clone(), functionx)) } diff --git a/source/vir/src/ast_visitor.rs b/source/vir/src/ast_visitor.rs index be5a5904c5..cfa44e8d1d 100644 --- a/source/vir/src/ast_visitor.rs +++ b/source/vir/src/ast_visitor.rs @@ -618,6 +618,7 @@ where require, ensure, ens_has_return: _, + returns, decrease, decrease_when, decrease_by: _, @@ -650,6 +651,10 @@ where } map.pop_scope(); + if let Some(e) = returns { + expr_visitor_control_flow!(expr_visitor_dfs(e, map, mf)); + } + for e in decrease.iter() { expr_visitor_control_flow!(expr_visitor_dfs(e, map, mf)); } @@ -1205,6 +1210,7 @@ where ens_has_return, require, ensure, + returns, decrease, decrease_when, decrease_by, @@ -1267,6 +1273,11 @@ where Arc::new(vec_map_result(ensure, |e| map_expr_visitor_env(e, map, env, fe, fs, ft))?); map.pop_scope(); + let returns = match returns { + Some(e) => Some(map_expr_visitor_env(e, map, env, fe, fs, ft)?), + None => None, + }; + let decrease = Arc::new(vec_map_result(decrease, |e| map_expr_visitor_env(e, map, env, fe, fs, ft))?); let decrease_when = decrease_when @@ -1329,6 +1340,7 @@ where ens_has_return: *ens_has_return, require, ensure, + returns, decrease, decrease_when, decrease_by, diff --git a/source/vir/src/headers.rs b/source/vir/src/headers.rs index 83c06e443f..3a09f9ff54 100644 --- a/source/vir/src/headers.rs +++ b/source/vir/src/headers.rs @@ -17,6 +17,7 @@ pub struct Header { pub recommend: Exprs, pub ensure_id_typ: Option<(VarIdent, Typ)>, pub ensure: Exprs, + pub returns: Option, pub invariant_except_break: Exprs, pub invariant: Exprs, pub decrease: Exprs, @@ -33,6 +34,7 @@ pub fn read_header_block(block: &mut Vec) -> Result { let mut extra_dependencies: Vec = Vec::new(); let mut require: Option = None; let mut ensure: Option<(Option<(VarIdent, Typ)>, Exprs)> = None; + let mut returns: Option = None; let mut recommend: Option = None; let mut invariant_except_break: Option = None; let mut invariant: Option = None; @@ -88,6 +90,12 @@ pub fn read_header_block(block: &mut Vec) -> Result { } ensure = Some((id_typ.clone(), es.clone())); } + HeaderExprX::Returns(e) => { + if returns.is_some() { + return Err(error(&stmt.span, "only one call to returns allowed")); + } + returns = Some(e.clone()); + } HeaderExprX::InvariantExceptBreak(es) => { if invariant_except_break.is_some() { return Err(error( @@ -206,6 +214,7 @@ pub fn read_header_block(block: &mut Vec) -> Result { recommend, ensure_id_typ, ensure, + returns, invariant_except_break, invariant, decrease, @@ -305,6 +314,7 @@ fn make_trait_decl(method: &Function, spec_method: &Function) -> Result Result Result { ens_has_return: default_function.x.ens_has_return, require: Arc::new(vec![]), ensure: Arc::new(vec![]), + returns: None, decrease: Arc::new(vec![]), decrease_when: None, decrease_by: None, @@ -923,22 +924,35 @@ pub fn merge_external_traits(krate: Krate) -> Result { pub fn fixup_ens_has_return_for_trait_method_impls(krate: Krate) -> Result { let mut krate = krate; let kratex = &mut Arc::make_mut(&mut krate); - let mut fun_map = HashMap::::new(); + let mut fun_map = HashMap::::new(); for function in kratex.functions.iter() { if matches!(function.x.kind, FunctionKind::TraitMethodDecl { .. }) { - fun_map.insert(function.x.name.clone(), function.x.ens_has_return); + fun_map.insert(function.x.name.clone(), function.clone()); } } for function in kratex.functions.iter_mut() { if let FunctionKind::TraitMethodImpl { method, .. } = &function.x.kind { + let method = method.clone(); if !function.x.ens_has_return { - match fun_map.get(method) { + match fun_map.get(&method) { None => {} - Some(true) => { + Some(f) if f.x.ens_has_return => { let functionx = &mut Arc::make_mut(&mut *function).x; functionx.ens_has_return = true; } - Some(false) => {} + Some(_) => {} + } + } + if function.x.returns.is_some() { + match fun_map.get(&method) { + None => {} + Some(f) if f.x.returns.is_some() => { + return Err(error( + &function.span, + "a `returns` clause cannot be declared on both a trait method impl and its declaration", + ).secondary_span(&f.span)); + } + Some(_) => {} } } } diff --git a/source/vir/src/well_formed.rs b/source/vir/src/well_formed.rs index d6a32878cc..fa676b705b 100644 --- a/source/vir/src/well_formed.rs +++ b/source/vir/src/well_formed.rs @@ -5,7 +5,8 @@ use crate::ast::{ }; use crate::ast_util::{ dt_as_friendly_rust_name, fun_as_friendly_rust_name, is_visible_to_opt, - path_as_friendly_rust_name, referenced_vars_expr, + path_as_friendly_rust_name, referenced_vars_expr, typ_to_diagnostic_str, types_equal, + undecorate_typ, }; use crate::datatype_to_air::is_datatype_transparent; use crate::def::user_local_name; @@ -524,6 +525,12 @@ fn check_function( "decreases_by/recommends_by function cannot have ensures clauses", )); } + if function.x.returns.is_some() { + return Err(error( + &function.span, + "decreases_by/recommends_by function cannot have ensures clauses", + )); + } if function.x.ens_has_return { return Err(error( &function.span, @@ -672,7 +679,6 @@ fn check_function( #[cfg(feature = "singular")] if function.x.attrs.integer_ring { - use crate::ast_util::undecorate_typ; let _ = match std::env::var("VERUS_SINGULAR_PATH") { Ok(_) => {} Err(_) => { @@ -767,6 +773,9 @@ fn check_function( Ok(()) })?; } + if let Some(r) = &function.x.returns { + return Err(error(&r.span, "integer_ring should not have a `returns` clause")); + } } if function.x.publish.is_some() && function.x.mode != Mode::Spec { @@ -799,6 +808,26 @@ fn check_function( let disallow_private_access = Some((&function.x.visibility.restricted_to, msg)); check_expr(ctxt, function, ens, disallow_private_access, Place::BodyOrPostState)?; } + if let Some(r) = &function.x.returns { + if !types_equal(&undecorate_typ(&r.typ), &undecorate_typ(&function.x.ret.x.typ)) { + return Err(error( + &r.span, + "type of `returns` clause does not match function return type", + ) + .secondary_label( + &function.span, + format!("this function returns `{}`", typ_to_diagnostic_str(&function.x.ret.x.typ)), + ) + .secondary_label( + &r.span, + format!("the `returns` clause has type `{}`", typ_to_diagnostic_str(&r.typ)), + )); + } + + let msg = "'requires' clause of public function"; + let disallow_private_access = Some((&function.x.visibility.restricted_to, msg)); + check_expr(ctxt, function, r, disallow_private_access, Place::PreState("returns"))?; + } match &function.x.mask_spec { None => {} Some(MaskSpec::InvariantOpens(es) | MaskSpec::InvariantOpensExcept(es)) => { @@ -1246,7 +1275,13 @@ pub fn check_crate( _ => VisitorControlFlow::Recurse, }; let mut found_trigger = false; - for expr in function.x.require.iter().chain(function.x.ensure.iter()) { + for expr in function + .x + .require + .iter() + .chain(function.x.ensure.iter()) + .chain(function.x.returns.iter()) + { let control = crate::ast_visitor::expr_visitor_dfs( expr, &mut air::scope_map::ScopeMap::new(),