-
Notifications
You must be signed in to change notification settings - Fork 750
feat: add cbv_eval attribute
#12296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
feat: add cbv_eval attribute
#12296
Changes from all commits
01978e1
724d08f
44f7912
8625a37
24b7293
a14a3f1
b8e03c6
0450fc6
e30fef1
8bdb14a
a8995a1
bc4df28
bcb40d4
621d4e9
9ab33ab
8188756
8fa7c9d
e076f39
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| /- | ||
| Copyright (c) 2026 Lean FRO, LLC. All rights reserved. | ||
| Released under Apache 2.0 license as described in the file LICENSE. | ||
| Authors: Wojciech Różowski | ||
| -/ | ||
| module | ||
| prelude | ||
| public import Lean.Data.NameMap | ||
| public import Lean.ScopedEnvExtension | ||
| public import Lean.Elab.InfoTree | ||
| public import Lean.Meta.Sym.Simp.Theorems | ||
|
|
||
| public section | ||
| namespace Lean.Meta.Sym.Simp | ||
|
|
||
| def Theorem.declName (thm : Theorem) : Name := thm.expr.getAppFn.constName! | ||
|
|
||
| def Theorem.isPrivate (thm : Theorem) : Bool := isPrivateName thm.declName | ||
|
|
||
| end Lean.Meta.Sym.Simp | ||
|
|
||
| namespace Lean.Meta.Tactic.Cbv | ||
| open Lean.Meta.Sym.Simp | ||
|
|
||
| structure CbvEvalEntry where | ||
| appFn : Name | ||
| thm : Theorem | ||
| deriving BEq, Inhabited | ||
|
|
||
| def mkCbvTheoremFromConst (declName : Name) : MetaM CbvEvalEntry := do | ||
| let cinfo ← getConstVal declName | ||
| let us := cinfo.levelParams.map mkLevelParam | ||
| let val := mkConst declName us | ||
| let type ← inferType val | ||
| unless (← isProp type) do throwError "{val} is not a theorem and thus cannot be marked with `cbv_eval` attribute" | ||
| let fnName ← forallTelescope type fun _ body => do | ||
| let some (_, lhs, _) := body.eq? | throwError "The conclusion {type} of theorem {val} is not an equality" | ||
| let appFn := lhs.getAppFn | ||
| let some constName := appFn.constName? | throwError "The left-hand side of a theorem {val} is not an application of a constant" | ||
| return constName | ||
| let thm ← mkTheoremFromDecl declName | ||
| return ⟨fnName, thm⟩ | ||
|
|
||
| structure CbvEvalState where | ||
| lemmas : NameMap Theorems := {} | ||
| deriving Inhabited | ||
|
|
||
| def CbvEvalState.addEntry (s : CbvEvalState) (e : CbvEvalEntry) : CbvEvalState := | ||
| let existing := (s.lemmas.find? e.appFn).getD {} | ||
| { s with lemmas := s.lemmas.insert e.appFn (existing.insert e.thm) } | ||
|
|
||
| abbrev CbvEvalExtension := SimpleScopedEnvExtension CbvEvalEntry CbvEvalState | ||
|
|
||
| builtin_initialize cbvEvalExt : CbvEvalExtension ← | ||
| registerSimpleScopedEnvExtension { | ||
| name := `cbvEvalExt | ||
| initial := {} | ||
| addEntry := CbvEvalState.addEntry | ||
| exportEntry? := fun level entry => do | ||
| guard (level == .private || !entry.thm.isPrivate) | ||
| return entry | ||
| } | ||
|
|
||
| def getCbvEvalLemmas (target : Name) : CoreM (Option Theorems) := do | ||
| let s := cbvEvalExt.getState (← getEnv) | ||
| return (s.lemmas.find? target) | ||
|
|
||
| syntax (name := Parser.Attr.cbvEval) "cbv_eval" : attr | ||
|
|
||
| builtin_initialize | ||
| registerBuiltinAttribute { | ||
| ref := `cbvEvalAttr | ||
| name := `cbv_eval | ||
| descr := "Register a theorem as a rewrite rule for CBV evaluation of a given definition. \ | ||
| Usage: @[cbv_eval] theorem ..." | ||
| applicationTime := AttributeApplicationTime.afterCompilation | ||
| add := fun lemmaName _ kind => do | ||
| let (entry, _) ← MetaM.run (mkCbvTheoremFromConst lemmaName) {} | ||
| cbvEvalExt.add entry kind | ||
| } | ||
|
|
||
| end Lean.Meta.Tactic.Cbv | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ public import Lean.Meta.Sym.Simp.SimpM | |
| public import Lean.Meta.Tactic.Cbv.Opaque | ||
| import Lean.Meta.Tactic.Cbv.Util | ||
| import Lean.Meta.Tactic.Cbv.TheoremsLookup | ||
| import Lean.Meta.Tactic.Cbv.CbvEvalExt | ||
| import Lean.Meta.Sym | ||
|
|
||
| namespace Lean.Meta.Tactic.Cbv | ||
|
|
@@ -69,23 +70,43 @@ def betaReduce : Simproc := fun e => do | |
| let new ← Sym.share new | ||
| return .step new (← Sym.mkEqRefl new) | ||
|
|
||
| def tryCbvTheorems : Simproc := fun e => do | ||
| let some fnName := e.getAppFn.constName? | return .rfl | ||
| let some evalLemmas ← getCbvEvalLemmas fnName | return .rfl | ||
| Theorems.rewrite evalLemmas (d := dischargeNone) e | ||
|
|
||
| def handleApp : Simproc := fun e => do | ||
| unless e.isApp do return .rfl | ||
| let fn := e.getAppFn | ||
| match fn with | ||
| | .const constName _ => | ||
| let info ← getConstInfo constName | ||
| (guardSimproc (fun _ => info.hasValue) handleConstApp) <|> reduceRecMatcher <| e | ||
| tryCbvTheorems <|> (guardSimproc (fun _ => info.hasValue) handleConstApp) <|> reduceRecMatcher <| e | ||
| | .lam .. => betaReduce e | ||
| | _ => return .rfl | ||
|
|
||
| def isOpaqueApp : Simproc := fun e => do | ||
| let some fnName := e.getAppFn.constName? | return .rfl | ||
| return .rfl (← isCbvOpaque fnName) | ||
| let hasTheorems := (← getCbvEvalLemmas fnName).isSome | ||
| if hasTheorems then | ||
| let res ← (simpAppArgs >> tryCbvTheorems) e | ||
| match res with | ||
| | .rfl false => return .rfl | ||
| | _ => return res | ||
| else | ||
| return .rfl (← isCbvOpaque fnName) | ||
|
Comment on lines
90
to
97
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you even want to support theorem when `isOpaque = false?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes! Imagine that we have a function that we have an optimisation for one set of inputs, but for the other ones, we would like to follow the standard path that involves using the associated equations/unfold equation. |
||
|
|
||
| def isOpaqueConst : Simproc := fun e => do | ||
| let .const constName _ := e | return .rfl | ||
| return .rfl (← isCbvOpaque constName) | ||
| let hasTheorems := (← getCbvEvalLemmas constName).isSome | ||
| if hasTheorems then | ||
| let res ← (tryCbvTheorems) e | ||
| match res with | ||
| | .rfl false => | ||
| return .rfl | ||
| | _ => return res | ||
| else | ||
| return .rfl (← isCbvOpaque constName) | ||
|
|
||
| def foldLit : Simproc := fun e => do | ||
| let some n := e.rawNatLit? | return .rfl | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| import CbvAttr.Tst |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| module | ||
|
|
||
| public def f3 (x : Nat) := | ||
| x + 1 | ||
|
|
||
| @[expose] public def f4 (x : Nat) := | ||
| x + 1 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| module | ||
|
|
||
| set_option cbv.warning false | ||
|
|
||
| @[cbv_opaque] public def f2 (x : Nat) := | ||
| x + 1 | ||
|
|
||
| private axiom myAx : f2 x = x + 1 | ||
|
|
||
| @[local cbv_eval] public theorem f2_spec : f2 x = x + 1 := myAx | ||
|
|
||
| example : f2 1 = 2 := by conv => lhs; cbv |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| module | ||
|
|
||
| set_option cbv.warning false | ||
|
|
||
| @[cbv_opaque] public def f5 (x : Nat) := | ||
| x + 1 | ||
|
|
||
| @[cbv_eval] private theorem f5_spec : f5 x = x + 1 := rfl | ||
|
|
||
| /- works locally -/ | ||
| example : f5 1 = 2 := by conv => lhs; cbv |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| module | ||
|
|
||
| set_option cbv.warning false | ||
|
|
||
| @[cbv_opaque] public def f1 (x : Nat) := | ||
| x + 1 | ||
|
|
||
| private axiom myAx : f1 x = x + 1 | ||
|
|
||
| @[cbv_eval] public theorem f1_spec : f1 x = x + 1 := myAx | ||
|
|
||
| example : f1 1 = 2 := by conv => lhs; cbv |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| module | ||
|
|
||
| import CbvAttr.PubliclyVisibleTheorem | ||
| import CbvAttr.PublicFunctionLocalTheorem | ||
| import CbvAttr.PublicFunction | ||
| import CbvAttr.PublicFunctionPrivateTheorem | ||
|
|
||
| set_option cbv.warning false | ||
|
|
||
| /- Function does not have an exposed body, but has a public theorem for unrolling it-/ | ||
| example : f1 1 = 2 := by conv => lhs; cbv | ||
|
|
||
| /- Function has an exposed body, public theorem for unrolling it, but the attribute is local -/ | ||
|
|
||
| /-- | ||
| error: unsolved goals | ||
| ⊢ f2 1 = 2 | ||
| -/ | ||
| #guard_msgs in | ||
| example : f2 1 = 2 := by conv => lhs; cbv | ||
|
|
||
| /- Function is public, but its body is not exposed -/ | ||
|
|
||
| /-- | ||
| error: unsolved goals | ||
| ⊢ f3 1 = 2 | ||
| -/ | ||
| #guard_msgs in | ||
| example : f3 1 = 2 := by conv => lhs; cbv | ||
|
|
||
| /- Public function, that has an exposed body -/ | ||
| example : f4 1 = 2 := by conv => lhs; cbv | ||
|
|
||
| /- Public function, private theorem-/ | ||
|
|
||
| /-- | ||
| error: unsolved goals | ||
| ⊢ f5 1 = 2 | ||
| -/ | ||
| #guard_msgs in | ||
| example : f5 1 = 2 := by conv => lhs; cbv |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| import Lake | ||
| open System Lake DSL | ||
|
|
||
| package user_attr | ||
| @[default_target] lean_lib CbvAttr |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| lean4 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| #!/usr/bin/env bash | ||
|
|
||
| rm -rf .lake/build | ||
| lake build |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this PR is a draft, so I am not sure much attention to polish I should give, but here are some comments necessary. In particular what is
target? Is it always the head symbol of the LHS of the theorem? If so, is it worth keeping it as a separate field?(Generally reviewing is easier if there are already comments)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, target is the extracted head symbol from LHS of the theorem.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see your point. We could move the logic and do this when adding a
Theoremobject to the state of the extension, by inferring the type of an expression and extracting the name of the appFn on the LHS.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually no, we want the name of appFn of the LHS of the theorem to be the key of the
NameMapstored in the expression, so the name of it should be known at the moment of callingaddon the extension (which is pure, and hence we cannot compute it inside ofadd)