@@ -186,10 +186,123 @@ module Partition = struct
186
186
let find x = Option.Monad_infix. (Hashtbl. find comps x >> = find_root) in
187
187
{roots; groups; find}
188
188
189
+ let equiv t x y = Option. equal Equiv. equal (t.find x) (t.find y)
190
+
191
+ (* The trivial partition with a single class, or zero if elts is empty *)
192
+ let trivial elts =
193
+ let head = Set. choose elts in
194
+ match head with
195
+ | None ->
196
+ let roots = [||] in
197
+ let groups = [||] in
198
+ let find _ = None in
199
+ {roots; groups; find}
200
+ | Some h ->
201
+ let roots = Array. create ~len: 1 h in
202
+ let groups = Array. create ~len: 1 (create_set elts) in
203
+ let find x = if Set. mem elts x then Some 0 else None in
204
+ {roots; groups; find}
205
+
206
+ (* The discrete partition with one class per element *)
207
+ let discrete elts =
208
+ let comparator = Set. comparator elts in
209
+ let {Comparator. compare} = comparator in
210
+ (* Produces a sorted array per the spec *)
211
+ let roots = Set. to_array elts in
212
+ let groups = elts |> Set. to_array |>
213
+ Array. map ~f: (fun x ->
214
+ object
215
+ method enum = Seq. return x
216
+ method mem y = compare x y = 0
217
+ end )
218
+ in
219
+ let find x = Array. binary_search roots ~compare `First_equal_to x in
220
+ {roots; groups; find}
221
+
222
+ (* Takes a partition and a congruence and splits each equivalence class into
223
+ elements related by the congruence.
224
+ Takes in a comparison function to test for membership in each class.
225
+ *)
226
+ let refine (type elt ) t ~equiv ~cmp =
227
+ let module T = Comparator. Make (struct
228
+ type t = elt
229
+ let compare = cmp
230
+ let sexp_of_t = sexp_of_opaque
231
+ end ) in
232
+ let comparator = T. comparator in
233
+ let refine_group g =
234
+ let rec insert elt output input = match input with
235
+ | [] -> Set. singleton ~comparator elt :: output
236
+ | group :: input ->
237
+ if equiv (Set. choose_exn group) elt
238
+ then List. rev_append ((Set. add group elt) :: output) input
239
+ else insert elt (group::output) input in
240
+ Seq. fold g#enum ~init: [] ~f: (fun groups elt ->
241
+ insert elt [] groups) |> List. rev_map ~f: create_set in
242
+ let groups_list = Array. fold t.groups ~init: [] ~f: (fun seqs g -> refine_group g @ seqs) in
243
+ let groups = Array. of_list groups_list in
244
+ Array. sort groups ~cmp: (fun s1 s2 ->
245
+ let h1 = Seq. hd_exn s1#enum in
246
+ let h2 = Seq. hd_exn s2#enum in
247
+ cmp h1 h2);
248
+ let roots = Array. map ~f: (fun s -> Seq. hd_exn s#enum) groups in
249
+ let find x = Array. binary_search roots ~compare: cmp `First_equal_to x in
250
+ {roots; groups; find}
251
+
252
+ (* Take two elements and combine their classes if both have a class,
253
+ do nothing otherwise *)
254
+ let union t x y =
255
+ (* Assuming i < j,
256
+ create a new array a', such that
257
+ Array.length a' = Array.length a - 1 and
258
+ a'[k] = a[k] when k < i
259
+ a'[i] = x
260
+ a'[j] = a[j+1]
261
+ a'[j+1] = a[j+2]
262
+ ...
263
+ *)
264
+ let array_replace a i j x =
265
+ assert (i < j && Array. length a > 0 );
266
+ Array. init (Array. length a - 1 )
267
+ ~f: (fun n -> if n < i then a.(n)
268
+ else if n = i then x
269
+ else if n < j then a.(n)
270
+ else a.(n+ 1 ))
271
+ in
272
+ if equiv t x y then t
273
+ else
274
+ match t.find x, t.find y with
275
+ | None , _ | _ ,None -> t
276
+ | Some i_x , Some i_y ->
277
+ let g_x, g_y = t.groups.(i_x), t.groups.(i_y) in
278
+ let u_g = object
279
+ method enum =
280
+ let s_x, s_y = g_x#enum, g_y#enum in
281
+ Seq. append s_x s_y
282
+ method mem x = g_x#mem x || g_y#mem x
283
+ end
284
+ in
285
+ (* min biased root *)
286
+ let i = Int. min i_x i_y in
287
+ let j = Int. max i_x i_y in
288
+ let u_root = t.roots.(i) in
289
+ let roots = array_replace t.roots i j u_root in
290
+ let groups = array_replace t.groups i j u_g in
291
+ let find x = Option.Monad_infix. (
292
+ t.find x >> |
293
+ (* By cases: if n < i or i < n < j then it is in one one
294
+ of the original classes, otherwise n = i, then one
295
+ should return i (as the class still contains these
296
+ elements) or n = j, in wich case these elements are now
297
+ in class i, or n > j, in which case we must left-shift them *)
298
+ fun n -> if n < j then n
299
+ else if n = j then i
300
+ else n - 1 ) in
301
+ {roots; groups; find}
302
+
189
303
let nth_group t n = Group. create t.roots.(n) t.groups.(n) n
190
304
let groups t = Seq. (range 0 (Array. length t.roots) >> | nth_group t)
191
305
let group t x = Option. (t.find x >> | nth_group t)
192
- let equiv t x y = Option. equal Equiv. equal (t.find x) (t.find y)
193
306
let number_of_groups t = Array. length t.roots
194
307
let of_equiv t i =
195
308
if i > = 0 && i < Array. length t.roots
0 commit comments