forked from leanprover/lean4
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path UnificationHint.lean
145 lines (126 loc) · 5.34 KB
/
UnificationHint.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.ScopedEnvExtension
import Lean.Util.Recognizers
import Lean.Meta.DiscrTree
import Lean.Meta.SynthInstance
namespace Lean.Meta
/-- Key type of the discrimination tree that indexes unification hints
(an alias for `DiscrTree.Key`). -/
abbrev UnificationHintKey : Type := DiscrTree.Key
/-- One entry of the unification-hint environment extension: the hint declaration `val`
together with the discrimination-tree `keys` it is stored under. -/
structure UnificationHintEntry where
  /-- Discrimination-tree path under which the hint is indexed. -/
  keys : Array UnificationHintKey
  /-- Name of the declaration carrying the hint. -/
  val  : Name
  deriving Inhabited
/-- Discrimination tree mapping keys to the names of unification-hint declarations. -/
abbrev UnificationHintTree : Type := DiscrTree Name
/-- State of the unification-hint extension: the index of all registered hints. -/
structure UnificationHints where
  /-- Hint declarations indexed by key; initially empty. -/
  discrTree : UnificationHintTree := DiscrTree.empty
  deriving Inhabited
/-- Pretty-prints a hint database by formatting its underlying discrimination tree. -/
instance : ToFormat UnificationHints :=
  ⟨fun hints => format hints.discrTree⟩
/-- `whnfCore` settings used whenever hint keys are built, inserted, or looked up:
iota reduction is turned off and projection reduction is set to `.no`. -/
def UnificationHints.config : WhnfCoreConfig :=
  { proj := .no, iota := false }
/-- Returns `hints` with entry `e` inserted into the discrimination tree under `e.keys`,
using `UnificationHints.config`. -/
def UnificationHints.add (hints : UnificationHints) (e : UnificationHintEntry) : UnificationHints :=
  let tree := hints.discrTree.insertCore e.keys e.val UnificationHints.config
  { hints with discrTree := tree }
-- Scoped environment extension storing all registered unification hints.
-- It starts from an empty database, and each added entry is folded in with `UnificationHints.add`.
builtin_initialize unificationHintExtension : SimpleScopedEnvExtension UnificationHintEntry UnificationHints ←
  registerSimpleScopedEnvExtension {
    initial  := {}
    addEntry := UnificationHints.add
  }
/-- An equation `lhs =?= rhs`, obtained by decoding an `Eq` proposition from a hint body. -/
structure UnificationConstraint where
  /-- Left-hand side of the equation. -/
  lhs : Expr
  /-- Right-hand side of the equation. -/
  rhs : Expr
/-- A decoded unification hint. -/
structure UnificationHint where
  /-- Equation whose sides are matched against the two sides of a unification problem. -/
  pattern     : UnificationConstraint
  /-- Equations that must then be solved, in order. -/
  constraints : List UnificationConstraint
/--
Decodes the body of a unification-hint declaration.

The body must be a non-dependent arrow chain `c₁ → … → cₙ → lhs = rhs` in which every
premise `cᵢ` and the conclusion are `Eq` propositions. The conclusion becomes the hint's
`pattern` and the premises become its `constraints`, in the order they appear.
Returns an error message for the first premise or conclusion that is not an equation,
or for a premise that the remainder of the chain depends on.
-/
private partial def decodeUnificationHint (e : Expr) : ExceptT MessageData Id UnificationHint :=
  go e #[]
where
  -- Turns an `Eq` proposition `a = b` into the constraint `a =?= b`.
  decodeConstraint (e : Expr) : ExceptT MessageData Id UnificationConstraint := do
    let some (_, lhs, rhs) := e.eq?
      | throw m!"invalid unification hint constraint, unexpected term{indentExpr e}"
    return { lhs := lhs, rhs := rhs }
  -- Walks the `∀`/arrow chain, collecting decoded premises in `acc`.
  go (e : Expr) (acc : Array UnificationConstraint) : ExceptT MessageData Id UnificationHint := do
    if let .forallE _ dom body _ := e then
      -- Decode the premise before checking dependencies, so a non-equation premise is reported first.
      let c ← decodeConstraint dom
      if body.hasLooseBVars then
        throw m!"invalid unification hint constraint, unexpected dependency{indentExpr e}"
      go body (acc.push c)
    else
      let pattern ← decodeConstraint e
      return { pattern := pattern, constraints := acc.toList }
/--
Checks that a decoded hint is self-consistent. Each constraint is unified with `isDefEq`,
in order, and then the pattern is unified. These calls may assign metavariables occurring
in the equations.

Throws an error naming the first equation whose sides fail to unify.

This function is not recursive, so it is an ordinary (non-`partial`) definition.
-/
private def validateHint (hint : UnificationHint) : MetaM Unit := do
  for c in hint.constraints do
    unless (← isDefEq c.lhs c.rhs) do
      throwError "invalid unification hint, failed to unify constraint left-hand-side{indentExpr c.lhs}\nwith right-hand-side{indentExpr c.rhs}"
  unless (← isDefEq hint.pattern.lhs hint.pattern.rhs) do
    throwError "invalid unification hint, failed to unify pattern left-hand-side{indentExpr hint.pattern.lhs}\nwith right-hand-side{indentExpr hint.pattern.rhs}"
/--
Registers the declaration `declName` as a unification hint, using attribute kind `kind`.

The steps are:
* open the declaration's value with `lambdaMetaTelescope`, inside `withNewMCtxDepth`;
* decode the result with `decodeUnificationHint`;
* compute discrimination-tree keys from the pattern's left-hand side;
* check the hint with `validateHint`;
* add the entry to `unificationHintExtension`.

Throws if `declName` has no value, or if the hint cannot be decoded or validated.
-/
def addUnificationHint (declName : Name) (kind : AttributeKind) : MetaM Unit :=
  withNewMCtxDepth do
    let some val := (← getConstInfo declName).value?
      | throwError "invalid unification hint, it must be a definition"
    let (_, _, body) ← lambdaMetaTelescope val
    let hint ← match decodeUnificationHint body with
      | Except.ok hint   => pure hint
      | Except.error msg => throwError msg
    -- Keys are computed before `validateHint`, which may assign metavariables via `isDefEq`.
    -- Keep this order.
    let keys ← DiscrTree.mkPath hint.pattern.lhs UnificationHints.config
    validateHint hint
    unificationHintExtension.add { keys := keys, val := declName } kind
-- Registers the `@[unification_hint]` attribute. It takes no arguments.
-- Applying it runs `addUnificationHint` on the tagged declaration and discards the `MetaM.run` result.
builtin_initialize
  registerBuiltinAttribute {
    descr := "unification hint"
    name  := `unification_hint
    add   := fun declName stx kind => do
      Attribute.Builtin.ensureNoArgs stx
      discard (addUnificationHint declName kind).run
  }
/--
Attempts to solve the unification problem `t =?= s` using the registered unification hints.

After emitting a trace message, returns `false` without consulting any hint in two cases:
when the `unificationHints` flag of the current `Meta` configuration is off, or when `t` is
a metavariable. Otherwise it retrieves candidate hints by matching `t` against the hint
discrimination tree and tries them in order. The result is `true` as soon as one candidate
succeeds, and `false` if none do.
-/
def tryUnificationHints (t s : Expr) : MetaM Bool := do
  trace[Meta.isDefEq.hint] "{t} =?= {s}"
  if !(← read).config.unificationHints || t.isMVar then
    return false
  let db := unificationHintExtension.getState (← getEnv)
  let candidates ← db.discrTree.getMatch t UnificationHints.config
  -- `anyM` stops at the first candidate that returns `true`.
  candidates.anyM tryCandidate
where
  -- Hint patterns are matched using reducible transparency only.
  isDefEqPattern (p e : Expr) : MetaM Bool :=
    withReducible <| Meta.isExprDefEqAux p e

  -- Tries a single hint declaration `candidate` against `t =?= s`, inside `checkpointDefEq`
  -- (presumably so that assignments are rolled back when the candidate fails).
  tryCandidate (candidate : Name) : MetaM Bool :=
    withTraceNode `Meta.isDefEq.hint
        (return m!"{exceptBoolEmoji ·} hint {candidate} at {t} =?= {s}") do
      checkpointDefEq do
        let cinfo ← getConstInfo candidate
        let levels ← cinfo.levelParams.mapM fun _ => mkFreshLevelMVar
        let (xs, bis, body) ← lambdaMetaTelescope (← instantiateValueLevelParams cinfo levels)
        -- Decode the hint and match its pattern against `t` and `s` with hints disabled,
        -- so that matching cannot re-enter hint resolution.
        let hint? ← withConfig (fun cfg => { cfg with unificationHints := false }) do
          let Except.ok hint := decodeUnificationHint body | return none
          -- `<&&>` short-circuits: the right side is only tried if the left side matched.
          let matched ← isDefEqPattern hint.pattern.lhs t <&&> isDefEqPattern hint.pattern.rhs s
          return if matched then some hint else none
        let some hint := hint? | return false
        trace[Meta.isDefEq.hint] "{candidate} succeeded, applying constraints"
        -- Every constraint must unify; `allM` stops at the first failure.
        unless (← hint.constraints.allM fun c => Meta.isExprDefEqAux c.lhs c.rhs) do
          return false
        -- Each instance-implicit parameter must be synthesizable and agree with its current value.
        for x in xs, bi in bis do
          if bi.isInstImplicit then
            let LOption.some inst ← trySynthInstance (← inferType x) | return false
            unless (← isDefEq x inst) do
              return false
        return true
-- Registers the `Meta.isDefEq.hint` trace class, which `tryUnificationHints` uses
-- to trace unification-hint attempts.
builtin_initialize
  registerTraceClass `Meta.isDefEq.hint
end Lean.Meta