forked from leanprover/lean4
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTrie.lean
202 lines (180 loc) · 5.93 KB
/
Trie.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
/-
Copyright (c) 2018 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Sebastian Ullrich, Leonardo de Moura, Joachim Breitner
A string trie data structure, used for tokenizing the Lean language
-/
import Lean.Data.Format
namespace Lean
namespace Data
/-
## Implementation notes
Tries typically have many nodes with small degree, where a linear scan
through the (compact) `ByteArray` is faster than using binary search or
search trees like `RBTree`.
Moreover, many nodes have degree 1, which justifies the special-case `node1`
constructor.
The code would be a bit less repetitive if we used something like the following
```
mutual
def Trie α := Option α × ByteAssoc α
inductive ByteAssoc α where
| leaf : ByteAssoc α
| node1 : UInt8 → Trie α → ByteAssoc α
| node : ByteArray → Array (Trie α) → ByteAssoc α
end
```
but that would come at the cost of extra indirections.
-/
/-- A Trie is a key-value store where the keys are of type `String`,
and the internal structure is a tree that branches on the (UTF-8) bytes of the string.

Every node carries an `Option α`: the value stored for the key that ends at that node,
or `none` if no value is stored there. -/
inductive Trie (α : Type) where
  /-- A node without children. -/
  | leaf : Option α → Trie α
  /-- A node with exactly one child, reached by following the given byte. -/
  | node1 : Option α → UInt8 → Trie α → Trie α
  /-- A node with several children. The bytes and the child tries are kept in two
  parallel arrays: the byte at index `i` leads to the child at index `i`. -/
  | node : Option α → ByteArray → Array (Trie α) → Trie α
namespace Trie
variable {α : Type}
/-- The `Trie` that contains no keys: a single leaf without a value. -/
def empty : Trie α := .leaf .none
/-- `{}` / `∅` denotes the empty trie. -/
instance : EmptyCollection (Trie α) where
  emptyCollection := empty
/-- The empty trie serves as the default inhabitant. -/
instance : Inhabited (Trie α) := ⟨empty⟩
/-- Inserts or updates the value at the given key `s`.

`f` is applied to the value currently stored at `s` (`none` if the key is not present)
and its result becomes the new value stored at `s`. -/
partial def upsert (t : Trie α) (s : String) (f : Option α → α) : Trie α :=
  -- Builds a fresh chain of `node1`s for the bytes of `s` from index `i` onwards,
  -- terminated by a leaf holding `f none`.
  let rec insertEmpty (i : Nat) : Trie α :=
    if h : i < s.utf8ByteSize then
      node1 none (s.getUtf8Byte i h) (insertEmpty (i + 1))
    else
      leaf (f none)
  -- Descends along the bytes of `s`, rebuilding the nodes on the path.
  let rec loop (i : Nat) (cur : Trie α) : Trie α :=
    if h : i < s.utf8ByteSize then
      let b := s.getUtf8Byte i h
      match cur with
      | leaf v =>
        node1 v b (insertEmpty (i + 1))
      | node1 v b' child =>
        if b == b' then
          node1 v b' (loop (i + 1) child)
        else
          -- A second branch appears: switch to the general `node` representation.
          node v (.mk #[b, b']) #[insertEmpty (i + 1), child]
      | node v bs children =>
        match bs.findIdx? (· == b) with
        | some idx => node v bs (children.modify idx (loop (i + 1)))
        | none => node v (bs.push b) (children.push (insertEmpty (i + 1)))
    else
      -- The key ends here: only the value of the current node changes.
      match cur with
      | leaf v => leaf (f v)
      | node1 v b' child => node1 (f v) b' child
      | node v bs children => node (f v) bs children
  loop 0 t
/-- Inserts `val` at the given key `s`, overriding an existing value if present. -/
partial def insert (t : Trie α) (s : String) (val : α) : Trie α :=
  t.upsert s fun _ => val
/-- Looks up the value stored at the given key `s`; returns `none` if there is none. -/
partial def find? (t : Trie α) (s : String) : Option α :=
  let rec loop (i : Nat) (cur : Trie α) : Option α :=
    if h : i < s.utf8ByteSize then
      -- There are bytes left: follow the edge labelled with the next byte, if it exists.
      let b := s.getUtf8Byte i h
      match cur with
      | leaf _ => none
      | node1 _ b' child =>
        if b == b' then loop (i + 1) child else none
      | node _ bs children =>
        match bs.findIdx? (· == b) with
        | some idx => loop (i + 1) (children.get! idx)
        | none => none
    else
      -- The whole key has been consumed: the answer is the value of the current node.
      match cur with
      | leaf val => val
      | node1 val _ _ => val
      | node val _ _ => val
  loop 0 t
/-- Returns an `Array` of all values in the trie, in no particular order. -/
partial def values (t : Trie α) : Array α :=
  ((go t).run #[]).2
where
  -- Pushes the values of the given subtrie onto the state array,
  -- a node's own value first, then the values of its children.
  go : Trie α → StateM (Array α) Unit
    | leaf val? => do
      match val? with
      | some v => modify fun acc => acc.push v
      | none => pure ()
    | node1 val? _ child => do
      match val? with
      | some v => modify fun acc => acc.push v
      | none => pure ()
      go child
    | node val? _ children => do
      match val? with
      | some v => modify fun acc => acc.push v
      | none => pure ()
      children.forM fun child => go child
/-- Returns all values whose keys have the given string `pre` as a prefix,
in no particular order. -/
partial def findPrefix (t : Trie α) (pre : String) : Array α := go t 0
where
  -- Walks down the trie along the bytes of `pre`; once `pre` is used up,
  -- everything below the node reached is a match.
  go (t : Trie α) (i : Nat) : Array α :=
    if h : i < pre.utf8ByteSize then
      let b := pre.getUtf8Byte i h
      match t with
      | leaf _ => #[]
      | node1 _ b' child =>
        if b == b' then go child (i + 1) else #[]
      | node _ bs children =>
        match bs.findIdx? (· == b) with
        | some idx => go (children.get! idx) (i + 1)
        | none => #[]
    else
      values t
/-- Finds the longest _key_ in the trie that is contained in the given string `s`
at position `i`, and returns the associated value (`none` if no key matches). -/
partial def matchPrefix (s : String) (t : Trie α) (i : String.Pos) : Option α :=
  -- `best` is the value of the longest key matched so far; a value found deeper
  -- in the trie replaces it.
  let rec loop (cur : Trie α) (j : Nat) (best : Option α) : Option α :=
    match cur with
    | leaf v => v <|> best
    | node1 v b' child =>
      let best := v <|> best
      if h : j < s.utf8ByteSize then
        let b := s.getUtf8Byte j h
        if b == b' then loop child (j + 1) best else best
      else
        best
    | node v bs children =>
      let best := v <|> best
      if h : j < s.utf8ByteSize then
        let b := s.getUtf8Byte j h
        match bs.findIdx? (· == b) with
        | some idx => loop (children.get! idx) (j + 1) best
        | none => best
      else
        best
  loop t i.byteIdx none
/-- Renders the edges below a node: for each child, the edge's byte followed by the
(nested and grouped) rendering of that child. Stored values are not shown. -/
private partial def toStringAux {α : Type} (t : Trie α) : List Format :=
  -- One edge: the byte, then the child's lines joined and indented by 4.
  let entry (c : UInt8) (sub : Trie α) : List Format :=
    [format (repr c), Format.group (Format.nest 4 (Format.joinSep (toStringAux sub) Format.line))]
  match t with
  | leaf _ => []
  | node1 _ c sub => entry c sub
  | node _ cs subs => List.join (List.zipWith entry cs.toList subs.toList)
/-- Pretty-prints the edge structure of a trie, one line-separated entry per item
produced by `toStringAux`. -/
instance {α : Type} : ToString (Trie α) where
  toString t := (Format.joinSep (toStringAux t) Format.line).pretty
end Trie
end Data
end Lean