Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Lean/Meta/Sym/Simp/Theorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ structure Theorem where
pattern : Pattern
/-- Right-hand side of the equation. -/
rhs : Expr
deriving Inhabited

instance : BEq Theorem where
beq thm₁ thm₂ := thm₁.expr == thm₂.expr
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Cbv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module
prelude
public import Lean.Meta.Tactic.Cbv.Main
public import Lean.Meta.Tactic.Cbv.Util
public import Lean.Meta.Tactic.Cbv.CbvEvalExt

public section

Expand Down
82 changes: 82 additions & 0 deletions src/Lean/Meta/Tactic/Cbv/CbvEvalExt.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Wojciech Różowski
-/
module
prelude
public import Lean.Data.NameMap
public import Lean.ScopedEnvExtension
public import Lean.Elab.InfoTree
public import Lean.Meta.Sym.Simp.Theorems

public section
namespace Lean.Meta.Sym.Simp

def Theorem.declName (thm : Theorem) : Name := thm.expr.getAppFn.constName!

def Theorem.isPrivate (thm : Theorem) : Bool := isPrivateName thm.declName

end Lean.Meta.Sym.Simp

namespace Lean.Meta.Tactic.Cbv
open Lean.Meta.Sym.Simp

structure CbvEvalEntry where
appFn : Name
thm : Theorem
deriving BEq, Inhabited
Comment on lines 25 to 28
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this PR is a draft, so I am not sure much attention to polish I should give, but here are some comments necessary. In particular what is target? Is it always the head symbol of the LHS of the theorem? If so, is it worth keeping it as a separate field?

(Generally reviewing is easier if there are already comments)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, target is the extracted head symbol from LHS of the theorem.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point. We could move the logic and do this when adding a Theorem object to the state of the extension, by inferring the type of an expression and extracting the name of the appFn on the LHS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually no, we want the name of appFn of the LHS of the theorem to be the key of the NameMap stored in the expression, so the name of it should be known at the moment of calling add on the extension (which is pure, and hence we cannot compute it inside of add)


def mkCbvTheoremFromConst (declName : Name) : MetaM CbvEvalEntry := do
let cinfo ← getConstVal declName
let us := cinfo.levelParams.map mkLevelParam
let val := mkConst declName us
let type ← inferType val
unless (← isProp type) do throwError "{val} is not a theorem and thus cannot be marked with `cbv_eval` attribute"
let fnName ← forallTelescope type fun _ body => do
let some (_, lhs, _) := body.eq? | throwError "The conclusion {type} of theorem {val} is not an equality"
let appFn := lhs.getAppFn
let some constName := appFn.constName? | throwError "The left-hand side of a theorem {val} is not an application of a constant"
return constName
let thm ← mkTheoremFromDecl declName
return ⟨fnName, thm⟩

structure CbvEvalState where
lemmas : NameMap Theorems := {}
deriving Inhabited

def CbvEvalState.addEntry (s : CbvEvalState) (e : CbvEvalEntry) : CbvEvalState :=
let existing := (s.lemmas.find? e.appFn).getD {}
{ s with lemmas := s.lemmas.insert e.appFn (existing.insert e.thm) }

abbrev CbvEvalExtension := SimpleScopedEnvExtension CbvEvalEntry CbvEvalState

builtin_initialize cbvEvalExt : CbvEvalExtension ←
registerSimpleScopedEnvExtension {
name := `cbvEvalExt
initial := {}
addEntry := CbvEvalState.addEntry
exportEntry? := fun level entry => do
guard (level == .private || !entry.thm.isPrivate)
return entry
}

def getCbvEvalLemmas (target : Name) : CoreM (Option Theorems) := do
let s := cbvEvalExt.getState (← getEnv)
return (s.lemmas.find? target)

syntax (name := Parser.Attr.cbvEval) "cbv_eval" : attr

builtin_initialize
registerBuiltinAttribute {
ref := `cbvEvalAttr
name := `cbv_eval
descr := "Register a theorem as a rewrite rule for CBV evaluation of a given definition. \
Usage: @[cbv_eval] theorem ..."
applicationTime := AttributeApplicationTime.afterCompilation
add := fun lemmaName _ kind => do
let (entry, _) ← MetaM.run (mkCbvTheoremFromConst lemmaName) {}
cbvEvalExt.add entry kind
}

end Lean.Meta.Tactic.Cbv
27 changes: 24 additions & 3 deletions src/Lean/Meta/Tactic/Cbv/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public import Lean.Meta.Sym.Simp.SimpM
public import Lean.Meta.Tactic.Cbv.Opaque
import Lean.Meta.Tactic.Cbv.Util
import Lean.Meta.Tactic.Cbv.TheoremsLookup
import Lean.Meta.Tactic.Cbv.CbvEvalExt
import Lean.Meta.Sym

namespace Lean.Meta.Tactic.Cbv
Expand Down Expand Up @@ -69,23 +70,43 @@ def betaReduce : Simproc := fun e => do
let new ← Sym.share new
return .step new (← Sym.mkEqRefl new)

def tryCbvTheorems : Simproc := fun e => do
let some fnName := e.getAppFn.constName? | return .rfl
let some evalLemmas ← getCbvEvalLemmas fnName | return .rfl
Theorems.rewrite evalLemmas (d := dischargeNone) e

def handleApp : Simproc := fun e => do
unless e.isApp do return .rfl
let fn := e.getAppFn
match fn with
| .const constName _ =>
let info ← getConstInfo constName
(guardSimproc (fun _ => info.hasValue) handleConstApp) <|> reduceRecMatcher <| e
tryCbvTheorems <|> (guardSimproc (fun _ => info.hasValue) handleConstApp) <|> reduceRecMatcher <| e
| .lam .. => betaReduce e
| _ => return .rfl

def isOpaqueApp : Simproc := fun e => do
let some fnName := e.getAppFn.constName? | return .rfl
return .rfl (← isCbvOpaque fnName)
let hasTheorems := (← getCbvEvalLemmas fnName).isSome
if hasTheorems then
let res ← (simpAppArgs >> tryCbvTheorems) e
match res with
| .rfl false => return .rfl
| _ => return res
else
return .rfl (← isCbvOpaque fnName)
Comment on lines 90 to 97
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you even want to support theorem when `isOpaque = false?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! Imagine that we have a function that we have an optimisation for one set of inputs, but for the other ones, we would like to follow the standard path that involves using the associated equations/unfold equation.


def isOpaqueConst : Simproc := fun e => do
let .const constName _ := e | return .rfl
return .rfl (← isCbvOpaque constName)
let hasTheorems := (← getCbvEvalLemmas constName).isSome
if hasTheorems then
let res ← (tryCbvTheorems) e
match res with
| .rfl false =>
return .rfl
| _ => return res
else
return .rfl (← isCbvOpaque constName)

def foldLit : Simproc := fun e => do
let some n := e.rawNatLit? | return .rfl
Expand Down
4 changes: 4 additions & 0 deletions src/Lean/Meta/Tactic/Cbv/TheoremsLookup.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ end Lean.Meta.Sym.Simp
namespace Lean.Meta.Tactic.Cbv
open Lean.Meta.Sym.Simp

/--
Get or create cached Theorems for function equations.
Retrieves equations via `getEqnsFor?` and caches the resulting Theorems object.
-/
public structure CbvTheoremsLookupState where
eqnTheorems : PHashMap Name Theorems := {}
unfoldTheorems : PHashMap Name Theorem := {}
Expand Down
49 changes: 43 additions & 6 deletions tests/lean/run/cbv1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,52 @@ example : Nat.brazilianFactorial 7 = 125411328000 := by

attribute [cbv_opaque] Std.DHashMap.emptyWithCapacity
attribute [cbv_opaque] Std.DHashMap.insert
attribute [cbv_opaque] Std.DHashMap.contains
attribute [cbv_opaque] Std.DHashMap.getEntry
attribute [cbv_opaque] Std.DHashMap.contains
attribute [cbv_eval Std.DHashMap.contains] Std.DHashMap.contains_emptyWithCapacity
attribute [cbv_eval Std.DHashMap.contains] Std.DHashMap.contains_insert

/--
error: unsolved goals
⊢ (Std.DHashMap.emptyWithCapacity.insert 5 3).contains 5 = true
-/
#guard_msgs in
example : ((Std.DHashMap.emptyWithCapacity : Std.DHashMap Nat (fun _ => Nat)).insert 5 3).contains 5 = true := by
conv =>
lhs
cbv

@[cbv_opaque] def opaque_const : Nat := Nat.zero

@[cbv_eval] theorem opaque_fn_spec : opaque_const = 0 := by rfl

example : opaque_const = 0 := by conv => lhs; cbv

def myAdd (m n : Nat) := match m with
| 0 => n
| m' + 1 => (myAdd m' n) + 1

@[cbv_eval] theorem myAdd_test : myAdd 22 23 = 45 := by rfl

theorem fast_path : myAdd 22 23 = 45 := by conv => lhs; cbv

/--
info: theorem fast_path : myAdd 22 23 = 45 :=
Eq.mpr
(id
((fun a a_1 e_a =>
Eq.rec (motive := fun a_2 e_a => ∀ (a_3 : Nat), (a = a_3) = (a_2 = a_3)) (fun a_2 => Eq.refl (a = a_2)) e_a)
(myAdd 22 23) 45 (Eq.trans myAdd_test (Eq.refl 45)) 45))
(Eq.refl 45)
-/
#guard_msgs in
#print fast_path

theorem slow_path : myAdd 0 1 = 1 := by conv => lhs; cbv

/--
info: theorem slow_path : myAdd 0 1 = 1 :=
Eq.mpr
(id
((fun a a_1 e_a =>
Eq.rec (motive := fun a_2 e_a => ∀ (a_3 : Nat), (a = a_3) = (a_2 = a_3)) (fun a_2 => Eq.refl (a = a_2)) e_a)
(myAdd 0 1) 1 (Eq.trans (myAdd.eq_1 1) (Eq.refl 1)) 1))
(Eq.refl 1)
-/
#guard_msgs in
#print slow_path
1 change: 1 addition & 0 deletions tests/pkg/cbv_attr/CbvAttr.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import CbvAttr.Tst
7 changes: 7 additions & 0 deletions tests/pkg/cbv_attr/CbvAttr/PublicFunction.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module

public def f3 (x : Nat) :=
x + 1

@[expose] public def f4 (x : Nat) :=
x + 1
12 changes: 12 additions & 0 deletions tests/pkg/cbv_attr/CbvAttr/PublicFunctionLocalTheorem.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module

set_option cbv.warning false

@[cbv_opaque] public def f2 (x : Nat) :=
x + 1

private axiom myAx : f2 x = x + 1

@[local cbv_eval] public theorem f2_spec : f2 x = x + 1 := myAx

example : f2 1 = 2 := by conv => lhs; cbv
11 changes: 11 additions & 0 deletions tests/pkg/cbv_attr/CbvAttr/PublicFunctionPrivateTheorem.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module

set_option cbv.warning false

@[cbv_opaque] public def f5 (x : Nat) :=
x + 1

@[cbv_eval] private theorem f5_spec : f5 x = x + 1 := rfl

/- works locally -/
example : f5 1 = 2 := by conv => lhs; cbv
12 changes: 12 additions & 0 deletions tests/pkg/cbv_attr/CbvAttr/PubliclyVisibleTheorem.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module

set_option cbv.warning false

@[cbv_opaque] public def f1 (x : Nat) :=
x + 1

private axiom myAx : f1 x = x + 1

@[cbv_eval] public theorem f1_spec : f1 x = x + 1 := myAx

example : f1 1 = 2 := by conv => lhs; cbv
41 changes: 41 additions & 0 deletions tests/pkg/cbv_attr/CbvAttr/Tst.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
module

import CbvAttr.PubliclyVisibleTheorem
import CbvAttr.PublicFunctionLocalTheorem
import CbvAttr.PublicFunction
import CbvAttr.PublicFunctionPrivateTheorem

set_option cbv.warning false

/- Function does not have an exposed body, but has a public theorem for unrolling it-/
example : f1 1 = 2 := by conv => lhs; cbv

/- Function has an exposed body, public theorem for unrolling it, but the attribute is local -/

/--
error: unsolved goals
⊢ f2 1 = 2
-/
#guard_msgs in
example : f2 1 = 2 := by conv => lhs; cbv

/- Function is public, but its body is not exposed -/

/--
error: unsolved goals
⊢ f3 1 = 2
-/
#guard_msgs in
example : f3 1 = 2 := by conv => lhs; cbv

/- Public function, that has an exposed body -/
example : f4 1 = 2 := by conv => lhs; cbv

/- Public function, private theorem-/

/--
error: unsolved goals
⊢ f5 1 = 2
-/
#guard_msgs in
example : f5 1 = 2 := by conv => lhs; cbv
5 changes: 5 additions & 0 deletions tests/pkg/cbv_attr/lakefile.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import Lake
open System Lake DSL

package user_attr
@[default_target] lean_lib CbvAttr
1 change: 1 addition & 0 deletions tests/pkg/cbv_attr/lean-toolchain
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
lean4
4 changes: 4 additions & 0 deletions tests/pkg/cbv_attr/test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/usr/bin/env bash

rm -rf .lake/build
lake build
Loading