Skip to content

Commit ded29fd

Browse files
codyrouxivg
authored andcommitted
implements some helper functions for creating and manipulating partit… (#892)
* implements some helper functions for creating and manipulating partitions * Address various issues wrt performance and clarity * Base test structure for partitions * Change merge to union, run ocp-indent on all modified files * Write a few tests for the partition module * Write a few tests for the partition module * Fix parenthesizing bug in tests
1 parent d2a18af commit ded29fd

File tree

4 files changed

+184
-2
lines changed

4 files changed

+184
-2
lines changed

lib/graphlib/graphlib.mli

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,27 @@ module Std : sig
491491

492492
type 'a t = 'a partition
493493

494+
(** [trivial s] creates the trivial partition with a single
495+
equivalence class containing every member of [s] *)
496+
val trivial : ('a, 'b) Set.t -> 'a t
497+
498+
(** [discrete s] returns the partition with one class per element of [s] *)
499+
val discrete : ('a, 'b) Set.t -> 'a t
500+
501+
(** [refine p ~rel ~comp] takes a partition [p], and refines it
502+
according to the equivalence relation [r], so that the
503+
resulting partition corresponds to the classes of [r], assuming
504+
that those classes are finer that the original [p].
505+
506+
Takes an additional [comp] argument to compare for equality
507+
within the equivalence classes. *)
508+
val refine : 'a t -> equiv:('a -> 'a -> bool) -> cmp:('a -> 'a -> int) -> 'a t
509+
510+
(** [union p x y] returns the partition p with the classes of [x]
511+
and [y] merged. Returns [p] unchanged if either [x] or [y] are
512+
not part of any equivalence class. *)
513+
val union : 'a t -> 'a -> 'a -> 'a t
514+
494515
(** [groups p] returns all partition cells of a partitioning [p] *)
495516
val groups : 'a t -> 'a group seq
496517

lib/graphlib/graphlib_graph.ml

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,123 @@ module Partition = struct
186186
let find x = Option.Monad_infix.(Hashtbl.find comps x >>= find_root) in
187187
{roots; groups; find}
188188

189+
let equiv t x y = Option.equal Equiv.equal (t.find x) (t.find y)
190+
191+
(* The trivial partition with a single class, or zero if elts is empty *)
192+
let trivial elts =
193+
let head = Set.choose elts in
194+
match head with
195+
| None ->
196+
let roots = [||] in
197+
let groups = [||] in
198+
let find _ = None in
199+
{roots; groups; find}
200+
| Some h ->
201+
let roots = Array.create ~len:1 h in
202+
let groups = Array.create ~len:1 (create_set elts) in
203+
let find x = if Set.mem elts x then Some 0 else None in
204+
{roots; groups; find}
205+
206+
(* The discrete partition with one class per element *)
207+
let discrete elts =
208+
let comparator = Set.comparator elts in
209+
let {Comparator.compare} = comparator in
210+
(* Produces a sorted array per the spec *)
211+
let roots = Set.to_array elts in
212+
let groups = elts |> Set.to_array |>
213+
Array.map ~f:(fun x ->
214+
object
215+
method enum = Seq.return x
216+
method mem y = compare x y = 0
217+
end)
218+
in
219+
let find x = Array.binary_search roots ~compare `First_equal_to x in
220+
{roots; groups; find}
221+
222+
(* Takes a partition and a congruence and splits each equivalence class into
223+
elements related by the congruence.
224+
Takes in a comparison function to test for membership in each class.
225+
*)
226+
let refine (type elt) t ~equiv ~cmp =
227+
let module T = Comparator.Make(struct
228+
type t = elt
229+
let compare = cmp
230+
let sexp_of_t = sexp_of_opaque
231+
end) in
232+
let comparator = T.comparator in
233+
let refine_group g =
234+
let rec insert elt output input = match input with
235+
| [] -> Set.singleton ~comparator elt :: output
236+
| group :: input ->
237+
if equiv (Set.choose_exn group) elt
238+
then List.rev_append ((Set.add group elt) :: output) input
239+
else insert elt (group::output) input in
240+
Seq.fold g#enum ~init:[] ~f:(fun groups elt ->
241+
insert elt [] groups) |> List.rev_map ~f:create_set in
242+
let groups_list = Array.fold t.groups ~init:[] ~f:(fun seqs g -> refine_group g @ seqs) in
243+
let groups = Array.of_list groups_list in
244+
Array.sort groups ~cmp:(fun s1 s2 ->
245+
let h1 = Seq.hd_exn s1#enum in
246+
let h2 = Seq.hd_exn s2#enum in
247+
cmp h1 h2);
248+
let roots = Array.map ~f:(fun s -> Seq.hd_exn s#enum) groups in
249+
let find x = Array.binary_search roots ~compare:cmp `First_equal_to x in
250+
{roots; groups; find}
251+
252+
(* Take two elements and combine their classes if both have a class,
253+
do nothing otherwise *)
254+
let union t x y =
255+
(* Assuming i < j,
256+
create a new array a', such that
257+
Array.length a' = Array.length a - 1 and
258+
a'[k] = a[k] when k < i
259+
a'[i] = x
260+
a'[j] = a[j+1]
261+
a'[j+1] = a[j+2]
262+
...
263+
*)
264+
let array_replace a i j x =
265+
assert (i < j && Array.length a > 0);
266+
Array.init (Array.length a - 1)
267+
~f:(fun n -> if n < i then a.(n)
268+
else if n = i then x
269+
else if n < j then a.(n)
270+
else a.(n+1))
271+
in
272+
if equiv t x y then t
273+
else
274+
match t.find x, t.find y with
275+
| None, _ | _,None -> t
276+
| Some i_x, Some i_y ->
277+
let g_x, g_y = t.groups.(i_x), t.groups.(i_y) in
278+
let u_g = object
279+
method enum =
280+
let s_x, s_y = g_x#enum, g_y#enum in
281+
Seq.append s_x s_y
282+
method mem x = g_x#mem x || g_y#mem x
283+
end
284+
in
285+
(* min biased root *)
286+
let i = Int.min i_x i_y in
287+
let j = Int.max i_x i_y in
288+
let u_root = t.roots.(i) in
289+
let roots = array_replace t.roots i j u_root in
290+
let groups = array_replace t.groups i j u_g in
291+
let find x = Option.Monad_infix.(
292+
t.find x >>|
293+
(* By cases: if n < i or i < n < j then it is in one one
294+
of the original classes, otherwise n = i, then one
295+
should return i (as the class still contains these
296+
elements) or n = j, in wich case these elements are now
297+
in class i, or n > j, in which case we must left-shift them *)
298+
fun n -> if n < j then n
299+
else if n = j then i
300+
else n - 1) in
301+
{roots; groups; find}
302+
189303
let nth_group t n = Group.create t.roots.(n) t.groups.(n) n
190304
let groups t = Seq.(range 0 (Array.length t.roots) >>| nth_group t)
191305
let group t x = Option.(t.find x >>| nth_group t)
192-
let equiv t x y = Option.equal Equiv.equal (t.find x) (t.find y)
193306
let number_of_groups t = Array.length t.roots
194307
let of_equiv t i =
195308
if i >= 0 && i < Array.length t.roots

lib/graphlib/graphlib_graph.mli

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,10 @@ end
206206

207207
module Partition : sig
208208
type 'a t = 'a partition
209+
val trivial : ('a, 'b) Set.t -> 'a t
210+
val discrete : ('a, 'b) Set.t -> 'a t
211+
val refine : 'a t -> equiv:('a -> 'a -> bool) -> cmp:('a -> 'a -> int) -> 'a t
212+
val union : 'a t -> 'a -> 'a -> 'a t
209213
val groups : 'a t -> 'a group Sequence.t
210214
val group : 'a t -> 'a -> 'a group option
211215
val equiv : 'a t -> 'a -> 'a -> bool

lib_test/bap_types/test_graph.ml

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,50 @@ end
564564

565565
module Test_int100 = Construction(Int100)
566566

567+
module Test_partition = struct
567568

569+
module P = Partition
570+
571+
let add x s = Set.add s x
572+
573+
let s = Set.empty Int.comparator
574+
|> add 0
575+
|> add 1
576+
|> add 2
577+
|> add 3
578+
|> add 4
579+
|> add 5
580+
|> add 6
581+
|> add 7
582+
|> add 8
583+
|> add 9
584+
|> add 10
585+
586+
let n = Set.length s
587+
588+
let trivial p _ = assert_bool "failed" (P.number_of_groups p = 1)
589+
590+
let discrete p _ = assert_bool "failed" (P.number_of_groups p = n)
591+
592+
let union p x y _ = assert_bool "failed" (P.equiv p x y)
593+
594+
let refine p equiv _ = assert_bool "failed"
595+
(Seq.for_all (P.groups p)
596+
~f:(fun g ->
597+
let x = Group.top g in
598+
Seq.for_all (Group.enum g) ~f:(fun y -> equiv x y)))
599+
600+
let equiv x y = (x - y) mod 2 = 0
601+
602+
let cmp x y = x - y
603+
604+
let suite () = [
605+
"Trivial invariant" >:: trivial (P.trivial s);
606+
"Discrete invariant" >:: discrete (P.discrete s);
607+
"Union invariant" >:: union (P.union (P.discrete s) 1 2) 1 2;
608+
"Refine invariant" >:: refine (P.refine (P.trivial s) equiv cmp) equiv
609+
]
610+
end
568611

569612
let suite () =
570613
"Graph" >::: [
@@ -573,5 +616,6 @@ let suite () =
573616
let module Test = Test_algo(G) in
574617
Test.suite (sprintf "%d" n));
575618
"Construction" >::: [Test_int100.suite];
576-
"IR" >::: Test_IR.suite ()
619+
"IR" >::: Test_IR.suite ();
620+
"Partition" >::: Test_partition.suite ()
577621
]

0 commit comments

Comments
 (0)