forked from leanprover/lean4
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPersistentHashMap.lean
350 lines (289 loc) · 14.5 KB
/
PersistentHashMap.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
namespace Lean
universe u v w w'
namespace PersistentHashMap
/--
One slot of an entries table. The parameter `σ` is the type used for child
nodes; `Node` instantiates it with `Node α β` itself.
-/
inductive Entry (α : Type u) (β : Type v) (σ : Type w) where
  /-- A key/value binding stored directly in the slot. -/
  | entry (key : α) (val : β) : Entry α β σ
  /-- A link to a child node of type `σ`. -/
  | ref (node : σ) : Entry α β σ
  /-- An unoccupied slot. -/
  | null : Entry α β σ

/-- The default entry is the unoccupied slot. -/
instance {α β σ} : Inhabited (Entry α β σ) where
  default := Entry.null
/--
A node of the tree that backs `PersistentHashMap`: either a table of `Entry`
slots or a pair of parallel key/value arrays.
-/
inductive Node (α : Type u) (β : Type v) : Type (max u v) where
  /-- A table of slots; a slot holds a binding, a child node, or nothing. -/
  | entries (es : Array (Entry α β (Node α β))) : Node α β
  /-- Parallel arrays of keys and values; `h` records that their sizes agree. -/
  | collision (ks : Array α) (vs : Array β) (h : ks.size = vs.size) : Node α β

/-- The default node is an entries node over the empty array. -/
instance {α β} : Inhabited (Node α β) where
  default := Node.entries #[]
/-- Number of hash bits used to pick a slot at one level (see `insertAux`). -/
abbrev shift : USize := 5

/-- Number of slots in an entries table: `2 ^ shift`. -/
abbrev branching : USize := .ofNat (2 ^ shift.toNat)

/-- Depth from which `insertAux` stops splitting collision nodes. -/
abbrev maxDepth : USize := 7

/--
Size bound for collision nodes: `insertAux` keeps a collision node as is while
it holds fewer than this many keys. Also used as the initial capacity in
`mkCollisionNode`.
-/
abbrev maxCollisions : Nat := 4
/-- An array of `branching` slots, every one of them `Entry.null`. -/
def mkEmptyEntriesArray {α β} : Array (Entry α β (Node α β)) :=
  Array.mkArray branching.toNat Entry.null
end PersistentHashMap
/--
A hash map stored as a tree of `PersistentHashMap.Node`s, together with a
binding counter. The default value is a root table of empty slots and a
counter of `0`.
-/
structure PersistentHashMap (α : Type u) (β : Type v) [BEq α] [Hashable α] where
/-- Root node of the tree; by default an entries node whose slots are all `Entry.null`. -/
root : PersistentHashMap.Node α β := PersistentHashMap.Node.entries PersistentHashMap.mkEmptyEntriesArray
/-- Binding count as tracked by `insert` and `erase`; `isEmpty` compares it with `0`. -/
size : Nat := 0
/-- Short alias for `PersistentHashMap`. -/
abbrev PHashMap (α : Type u) (β : Type v) [BEq α] [Hashable α] := PersistentHashMap α β
namespace PersistentHashMap
/-- The map built from the structure defaults: an all-null root table and size `0`. -/
def empty [BEq α] [Hashable α] : PersistentHashMap α β := {}

/-- `true` exactly when the map's `size` field is `0`. -/
def isEmpty [BEq α] [Hashable α] (m : PersistentHashMap α β) : Bool :=
  match m.size with
  | 0 => true
  | _ + 1 => false

/-- The default map is the one built from the structure defaults. -/
instance [BEq α] [Hashable α] : Inhabited (PersistentHashMap α β) where
  default := {}
/-- An entries node whose table is `mkEmptyEntriesArray` (all slots `Entry.null`). -/
def mkEmptyEntries {α β} : Node α β :=
  .entries mkEmptyEntriesArray
/-- `i` shifted left by `shift` bits. -/
abbrev mul2Shift (i : USize) (shift : USize) : USize := USize.shiftLeft i shift

/-- `i` shifted right by `shift` bits. -/
abbrev div2Shift (i : USize) (shift : USize) : USize := USize.shiftRight i shift

/-- `i` masked with `(1 <<< shift) - 1`, i.e. its low `shift` bits. -/
abbrev mod2Shift (i : USize) (shift : USize) : USize :=
  i.land ((1 : USize).shiftLeft shift - 1)
/-- Predicate on nodes whose only constructor targets `Node.collision`. -/
inductive IsCollisionNode : Node α β → Prop where
| mk (keys : Array α) (vals : Array β) (h : keys.size = vals.size) : IsCollisionNode (Node.collision keys vals h)
/-- Nodes packaged with a proof of `IsCollisionNode`. -/
abbrev CollisionNode (α β) := { n : Node α β // IsCollisionNode n }
/-- Predicate on nodes whose only constructor targets `Node.entries`. -/
inductive IsEntriesNode : Node α β → Prop where
| mk (entries : Array (Entry α β (Node α β))) : IsEntriesNode (Node.entries entries)
/-- Nodes packaged with a proof of `IsEntriesNode`. -/
abbrev EntriesNode (α β) := { n : Node α β // IsEntriesNode n }
/--
If `ks` and `vs` have the same size, they still do after overwriting one
element of each with `set`. Used to rebuild a `Node.collision` after a replace.
-/
private theorem size_set {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (i : Fin ks.size) (j : Fin vs.size) (k : α) (v : β)
: (ks.set i k).size = (vs.set j v).size := by
simp [h]
/--
If `ks` and `vs` have the same size, they still do after pushing one element
onto each. Used to rebuild a `Node.collision` after an append.
-/
private theorem size_push {ks : Array α} {vs : Array β} (h : ks.size = vs.size) (k : α) (v : β) : (ks.push k).size = (vs.push v).size := by
simp [h]
/--
Scans the keys of a collision node starting at position `pos`. If a key equal
(via `==`) to the given key is found, both the key and the value at that
position are overwritten; if the scan runs off the end, the binding is pushed
onto the arrays.
-/
partial def insertAtCollisionNodeAux [BEq α] : CollisionNode α β → Nat → α → β → CollisionNode α β
  | node@⟨Node.collision ks vs hsize, _⟩, pos, key, val =>
    if hlt : pos < ks.size then
      let kIdx : Fin ks.size := ⟨pos, hlt⟩
      if key == ks.get kIdx then
        -- Same position in the value array; in bounds because the sizes agree.
        let vIdx : Fin vs.size := ⟨pos, by rw [←hsize]; assumption⟩
        ⟨Node.collision (ks.set kIdx key) (vs.set vIdx val) (size_set hsize kIdx vIdx key val), IsCollisionNode.mk _ _ _⟩
      else
        insertAtCollisionNodeAux node (pos+1) key val
    else
      ⟨Node.collision (ks.push key) (vs.push val) (size_push hsize key val), IsCollisionNode.mk _ _ _⟩
  -- A `CollisionNode` can never be an entries node.
  | ⟨Node.entries _, h⟩, _, _, _ => False.elim (nomatch h)
/-- Inserts or replaces a binding in a collision node, scanning from position `0`. -/
def insertAtCollisionNode [BEq α] : CollisionNode α β → α → β → CollisionNode α β
  | node, key, val => insertAtCollisionNodeAux node 0 key val
/-- Number of keys stored in a collision node. -/
def getCollisionNodeSize : CollisionNode α β → Nat
  -- A `CollisionNode` can never be an entries node.
  | ⟨Node.entries _, h⟩ => False.elim (nomatch h)
  | ⟨Node.collision ks _ _, _⟩ => ks.size
/--
Builds a collision node holding exactly the two bindings `k₁ ↦ v₁` and
`k₂ ↦ v₂`, in that order. Both arrays are created with capacity
`maxCollisions`.
-/
def mkCollisionNode (k₁ : α) (v₁ : β) (k₂ : α) (v₂ : β) : Node α β :=
  let noKeys : Array α := Array.mkEmpty maxCollisions
  let noVals : Array β := Array.mkEmpty maxCollisions
  Node.collision ((noKeys.push k₁).push k₂) ((noVals.push v₁).push v₂) rfl
/--
Inserts `k ↦ v` below the given node. The first `USize` is the key's hash as
already shifted for this level, the second is the current depth (`insert`
starts at `1`).

* Collision node: the binding is inserted or replaced in place. If the depth
  has reached `maxDepth`, or the node still has fewer than `maxCollisions`
  keys, the collision node is returned. Otherwise all of its bindings are
  re-inserted, at the same depth, into a fresh entries node.
* Entries node: the slot is chosen by the low `shift` bits of the hash. An
  empty slot receives the binding; a child link recurses with the hash shifted
  right by `shift` and the depth increased by one; an existing binding is
  replaced when the keys are equal, and otherwise both bindings are moved into
  a new collision node.
-/
partial def insertAux [BEq α] [Hashable α] : Node α β → USize → USize → α → β → Node α β
  | Node.collision keys vals heq, _, depth, k, v =>
    let updated := insertAtCollisionNode ⟨Node.collision keys vals heq, IsCollisionNode.mk _ _ _⟩ k v
    if depth >= maxDepth || getCollisionNodeSize updated < maxCollisions then
      updated.val
    else
      match updated with
      | ⟨Node.entries _, hImpossible⟩ => False.elim (nomatch hImpossible)
      | ⟨Node.collision allKeys allVals hsize, _⟩ =>
        let rec traverse (i : Nat) (entries : Node α β) : Node α β :=
          if hlt : i < allKeys.size then
            have : i < allVals.size := hsize ▸ hlt
            let key := allKeys[i]
            let val := allVals[i]
            -- Recompute the key's hash and drop the bits consumed above this depth.
            let keyHash := div2Shift (hash key |>.toUSize) (shift * (depth - 1))
            traverse (i+1) (insertAux entries keyHash depth key val)
          else
            entries
        traverse 0 mkEmptyEntries
  | Node.entries entries, h, depth, k, v =>
    let slot := (mod2Shift h shift).toNat
    Node.entries <| entries.modify slot fun
      | Entry.null => Entry.entry k v
      | Entry.ref child => Entry.ref <| insertAux child (div2Shift h shift) (depth+1) k v
      | Entry.entry k' v' =>
        if k == k' then Entry.entry k v
        else Entry.ref <| mkCollisionNode k' v' k v
/--
Membership test used by `insert` to decide whether a key is new. It follows
the same path as `insertAux`: the slot is chosen by the low `shift` bits of
the hash, child links are followed with the hash shifted right by `shift`,
and collision nodes are searched linearly with `==`.

It is declared here because the public lookup functions appear later in the
file and are therefore not yet in scope.
-/
private partial def insertContainsAux [BEq α] : Node α β → USize → α → Bool
  | Node.entries entries, h, k =>
    match entries.get! (mod2Shift h shift).toNat with
    | Entry.null => false
    | Entry.ref child => insertContainsAux child (div2Shift h shift) k
    | Entry.entry k' _ => k == k'
  | Node.collision keys _ _, _, k => keys.any fun k' => k == k'

/--
Returns the map with `k` bound to `v`. If `k` is already bound, its value is
replaced and `size` is left unchanged; otherwise `size` grows by one.

Previously `size` was incremented on every call, so re-inserting an existing
key made `size` over-count the bindings.
-/
def insert {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → β → PersistentHashMap α β
  | { root := n, size := sz }, k, v =>
    let h := hash k |>.toUSize
    -- Decide this before `insertAux` runs, while `n` is still the old tree.
    let isNew := !insertContainsAux n h k
    { root := insertAux n h 1 k v, size := if isNew then sz + 1 else sz }
/--
Linear search through parallel key/value arrays starting at index `i`.
Returns the value stored at the first position whose key satisfies
`k == key`, or `none` if there is no such position.
-/
partial def findAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option β :=
  if hlt : i < keys.size then
    have : i < vals.size := heq ▸ hlt
    if k == keys[i] then some vals[i]
    else findAtAux keys vals heq (i+1) k
  else
    none
/--
Looks up `k` below a node. The `USize` is the key's hash as shifted for this
level: its low `shift` bits select the slot, and child links are followed with
the hash shifted right by `shift`. Collision nodes are searched linearly.
-/
partial def findAux [BEq α] : Node α β → USize → α → Option β
  | Node.collision keys vals heq, _, k => findAtAux keys vals heq 0 k
  | Node.entries entries, h, k =>
    match entries.get! (mod2Shift h shift).toNat with
    | Entry.entry k' v => if k == k' then some v else none
    | Entry.ref child => findAux child (div2Shift h shift) k
    | Entry.null => none
/-- The value bound to the key, or `none` if the key is not found from the root. -/
def find? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Option β :=
  fun map key => findAux map.root (hash key).toUSize key
/-- `m[k]` is `m.find? k`; the index validity predicate is trivially `True`. -/
instance {_ : BEq α} {_ : Hashable α} : GetElem (PersistentHashMap α β) α (Option β) fun _ _ => True where
  getElem map key _ := map.find? key
/-- The value bound to `a`, or the fallback `b₀` when `find?` returns `none`. -/
@[inline] def findD {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) (a : α) (b₀ : β) : β :=
  match m.find? a with
  | some b => b
  | none => b₀
/-- The value bound to `a`; panics when `find?` returns `none`. -/
@[inline] def find! {_ : BEq α} {_ : Hashable α} [Inhabited β] (m : PersistentHashMap α β) (a : α) : β :=
  if let some b := m.find? a then b
  else panic! "key is not in the map"
/--
Like `findAtAux`, but returns the stored key together with its value: the
first position at or after `i` whose key satisfies `k == key`, or `none`.
-/
partial def findEntryAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Option (α × β) :=
  if hlt : i < keys.size then
    have : i < vals.size := heq ▸ hlt
    let stored := keys[i]
    if k == stored then some (stored, vals[i])
    else findEntryAtAux keys vals heq (i+1) k
  else
    none
/--
Looks up `k` below a node and returns the stored key with its value. Slot
selection and descent are the same as in `findAux`.
-/
partial def findEntryAux [BEq α] : Node α β → USize → α → Option (α × β)
  | Node.collision keys vals heq, _, k => findEntryAtAux keys vals heq 0 k
  | Node.entries entries, h, k =>
    match entries.get! (mod2Shift h shift).toNat with
    | Entry.entry k' v => if k == k' then some (k', v) else none
    | Entry.ref child => findEntryAux child (div2Shift h shift) k
    | Entry.null => none
/-- The stored key and value matching the given key, or `none` if there is no match. -/
def findEntry? {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → Option (α × β) :=
  fun map key => findEntryAux map.root (hash key).toUSize key
/--
`true` iff some key at index `i` or later satisfies `k == key`. The `vals`
and `heq` arguments are only threaded through the recursion; the body never
reads them.
-/
partial def containsAtAux [BEq α] (keys : Array α) (vals : Array β) (heq : keys.size = vals.size) (i : Nat) (k : α) : Bool :=
  if hlt : i < keys.size then
    k == keys[i] || containsAtAux keys vals heq (i+1) k
  else
    false
/--
`true` iff `k` is found below the node. Slot selection and descent are the
same as in `findAux`.
-/
partial def containsAux [BEq α] : Node α β → USize → α → Bool
  | Node.collision keys vals heq, _, k => containsAtAux keys vals heq 0 k
  | Node.entries entries, h, k =>
    match entries.get! (mod2Shift h shift).toNat with
    | Entry.entry k' _ => k == k'
    | Entry.ref child => containsAux child (div2Shift h shift) k
    | Entry.null => false
/-- `true` iff the key is found when searching from the root. -/
def contains [BEq α] [Hashable α] : PersistentHashMap α β → α → Bool :=
  fun map key => containsAux map.root (hash key).toUSize key
/--
Scans the slots of `a` from index `i`, with `acc` holding the binding seen so
far. Returns `none` as soon as a child link or a second binding is met;
otherwise returns `acc` once the scan reaches the end of the array.
-/
partial def isUnaryEntries (a : Array (Entry α β (Node α β))) (i : Nat) (acc : Option (α × β)) : Option (α × β) :=
  if h : i < a.size then
    match a[i], acc with
    | Entry.null, _ => isUnaryEntries a (i+1) acc
    | Entry.ref _, _ => none
    | Entry.entry k v, none => isUnaryEntries a (i+1) (some (k, v))
    | Entry.entry _ _, some _ => none
  else
    acc
/--
If the node holds exactly one binding and no child links, returns that
binding; otherwise `none`. For a collision node this means the key array has
size `1`.
-/
def isUnaryNode : Node α β → Option (α × β)
  | Node.collision keys vals heq =>
    if hone : 1 = keys.size then
      have : 0 < keys.size := by rw [←hone]; decide
      have : 0 < vals.size := by rw [←heq]; assumption
      some (keys[0], vals[0])
    else
      none
  | Node.entries entries => isUnaryEntries entries 0 none
/--
Removes `k` from below a node. The `USize` is the key's hash as shifted for
this level. Returns the resulting node and a flag that is `true` iff a binding
was removed; when nothing is removed the original node is returned.
-/
partial def eraseAux [BEq α] : Node α β → USize → α → Node α β × Bool
| n@(Node.collision keys vals heq), _, k =>
match keys.indexOf? k with
| some idx =>
-- Drop position `idx` from both arrays; `keq`/`veq` describe the new sizes.
let ⟨keys', keq⟩ := keys.eraseIdx' idx
-- `idx` is transported along `heq` so it can index `vals`.
let ⟨vals', veq⟩ := vals.eraseIdx' (Eq.ndrec idx heq)
have : keys.size - 1 = vals.size - 1 := by rw [heq]
(Node.collision keys' vals' (keq.trans (this.trans veq.symm)), true)
| none => (n, false)
| n@(Node.entries entries), h, k =>
-- Slot chosen by the low `shift` bits of the hash.
let j := (mod2Shift h shift).toNat
let entry := entries.get! j
match entry with
| Entry.null => (n, false)
| Entry.entry k' _ =>
if k == k' then (Node.entries (entries.set! j Entry.null), true) else (n, false)
| Entry.ref node =>
-- NOTE(review): the slot is cleared before recursing, presumably so this table
-- no longer references the child while it is being updated — confirm.
let entries := entries.set! j Entry.null
let (newNode, deleted) := eraseAux node (div2Shift h shift) k
if !deleted then (n, false)
else match isUnaryNode newNode with
-- Child still holds more than one binding (or links): keep it as a child.
| none => (Node.entries (entries.set! j (Entry.ref newNode)), true)
-- Child shrank to a single binding: store that binding directly in the slot.
| some (k, v) => (Node.entries (entries.set! j (Entry.entry k v)), true)
/--
Returns the map without the given key. `size` is decreased by one only when
`eraseAux` reports that a binding was removed.
-/
def erase {_ : BEq α} {_ : Hashable α} : PersistentHashMap α β → α → PersistentHashMap α β
  | { root := oldRoot, size := count }, key =>
    let (newRoot, removed) := eraseAux oldRoot (hash key).toUSize key
    { root := newRoot, size := if removed then count - 1 else count }
section
variable {m : Type w → Type w'} [Monad m]
variable {σ : Type w}
/--
Monadic left fold over every binding below a node. Collision nodes are walked
by index; entries nodes are folded slot by slot, skipping empty slots and
recursing into child links.
-/
partial def foldlMAux (f : σ → α → β → m σ) : Node α β → σ → m σ
  | Node.collision keys vals heq, acc =>
    let rec traverse (i : Nat) (acc : σ) : m σ := do
      if hlt : i < keys.size then
        have : i < vals.size := heq ▸ hlt
        let next ← f acc keys[i] vals[i]
        traverse (i+1) next
      else
        pure acc
    traverse 0 acc
  | Node.entries entries, acc =>
    entries.foldlM (init := acc) fun state entry =>
      match entry with
      | Entry.null => pure state
      | Entry.entry k v => f state k v
      | Entry.ref child => foldlMAux f child state
/-- Monadic left fold over all bindings of the map, starting from `init`. -/
def foldlM {_ : BEq α} {_ : Hashable α} (map : PersistentHashMap α β) (f : σ → α → β → m σ) (init : σ) : m σ :=
  match map with
  | { root := tree, .. } => foldlMAux f tree init

/-- Runs `f` on every binding of the map, for its effects only. -/
def forM {_ : BEq α} {_ : Hashable α} (map : PersistentHashMap α β) (f : α → β → m PUnit) : m PUnit :=
  map.foldlM (fun _ key val => f key val) ⟨⟩

/-- Pure left fold over all bindings of the map: `foldlM` run in `Id`. -/
def foldl {_ : BEq α} {_ : Hashable α} (map : PersistentHashMap α β) (f : σ → α → β → σ) (init : σ) : σ :=
  Id.run (map.foldlM f init)
/--
`for … in` support: folds `f` over the bindings as pairs `(key, value)`,
threading a state that starts at `init`. A `.done` step ends the loop with
its state; a `.yield` step continues with its state.
-/
protected def forIn {_ : BEq α} {_ : Hashable α} [Monad m]
(map : PersistentHashMap α β) (init : σ) (f : α × β → σ → m (ForInStep σ)) : m σ := do
-- Encode "stop" as the error case so the `ExceptT` fold below stops early.
let intoError : ForInStep σ → Except σ σ
| .done s => .error s
| .yield s => .ok s
-- The fold runs in `ExceptT σ m`; the ascription views the step as a plain `m` action.
let result ← foldlM (m := ExceptT σ m) map (init := init) fun s a b =>
(intoError <$> f (a, b) s : m _)
-- Either way the carried state is the loop's result.
match result with
| .ok s | .error s => pure s
/-- Registers `PersistentHashMap.forIn` as the `ForIn` implementation, iterating over `α × β` pairs. -/
instance {_ : BEq α} {_ : Hashable α} : ForIn m (PersistentHashMap α β) (α × β) where
forIn := PersistentHashMap.forIn
end
/--
Applies the monadic function `f` to every value below node `n`, keeping keys
and tree shape unchanged. Collision nodes map their value array (the size
proof is transported to the new array); entries nodes map each slot and
recurse into child links.
-/
partial def mapMAux {α : Type u} {β : Type v} {σ : Type u} {m : Type u → Type w} [Monad m] (f : β → m σ) (n : Node α β) : m (Node α σ) := do
  match n with
  | .collision keys vals heq =>
    let ⟨newVals, hsize⟩ ← vals.mapM' f
    return .collision keys newVals (hsize ▸ heq)
  | .entries entries =>
    let newEntries ← entries.mapM fun
      | .null => return .null
      | .entry k v => return .entry k (← f v)
      | .ref child => return .ref (← mapMAux f child)
    return .entries newEntries
/--
Applies the monadic function `f` to every value of `pm`. Only the root is
replaced; the other fields are copied from `pm`.
-/
def mapM {α : Type u} {β : Type v} {σ : Type u} {m : Type u → Type w} [Monad m] {_ : BEq α} {_ : Hashable α} (pm : PersistentHashMap α β) (f : β → m σ) : m (PersistentHashMap α σ) := do
  let newRoot ← mapMAux f pm.root
  return { pm with root := newRoot }

/-- Applies the pure function `f` to every value of `pm`: `mapM` run in `Id`. -/
def map {α : Type u} {β : Type v} {σ : Type u} {_ : BEq α} {_ : Hashable α} (pm : PersistentHashMap α β) (f : β → σ) : PersistentHashMap α σ :=
  Id.run (pm.mapM f)
/--
All bindings of the map as a list of pairs. Each visited binding is consed
onto the front, so the list is in reverse fold order.
-/
def toList {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) : List (α × β) :=
  m.foldl (fun acc key val => (key, val) :: acc) []
/-- Counters filled in by `collectStats`; every field defaults to `0`. -/
structure Stats where
/-- Number of nodes visited (entries and collision nodes). -/
numNodes : Nat := 0
/-- Number of `Entry.null` slots seen in entries nodes. -/
numNull : Nat := 0
/-- Accumulated `keys.size - 1` over the collision nodes visited. -/
numCollisions : Nat := 0
/-- Largest depth passed to `collectStats` for any visited node. -/
maxDepth : Nat := 0
/--
Accumulates statistics for the subtree rooted at a node, where the `Nat` is
that node's depth. Every node bumps `numNodes` and updates `maxDepth`; a
collision node also adds `keys.size - 1` to `numCollisions`; an entries node
counts its empty slots and recurses into child links one level deeper.
-/
partial def collectStats : Node α β → Stats → Nat → Stats
  | Node.collision keys _ _, stats, depth =>
    { stats with
      numNodes := stats.numNodes + 1,
      numCollisions := stats.numCollisions + keys.size - 1,
      maxDepth := Nat.max stats.maxDepth depth }
  | Node.entries entries, stats, depth =>
    let visited : Stats :=
      { stats with
        numNodes := stats.numNodes + 1,
        maxDepth := Nat.max stats.maxDepth depth }
    entries.foldl (init := visited) fun acc entry =>
      match entry with
      | Entry.null => { acc with numNull := acc.numNull + 1 }
      | Entry.ref child => collectStats child acc (depth + 1)
      | Entry.entry _ _ => acc
/-- Statistics for the whole map, with the root counted as depth `1`. -/
def stats {_ : BEq α} {_ : Hashable α} (m : PersistentHashMap α β) : Stats :=
  let fresh : Stats := {}
  collectStats m.root fresh 1
/-- Renders the four counters in a record-like layout on a single line. -/
def Stats.toString (s : Stats) : String :=
  s!"\{ nodes := {s.numNodes}, null := {s.numNull}, collisions := {s.numCollisions}, depth := {s.maxDepth}}"

/-- `toString` for `Stats` is `Stats.toString`. -/
instance : ToString Stats where
  toString := Stats.toString