Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorizer refactor heuristic for select_and_join #3449

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 81 additions & 49 deletions backend/cfg/vectorize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -355,11 +355,11 @@ module Block : sig

val find : t -> Instruction.Id.t -> Instruction.t

(** [find_last_instruction t instrs] returns instruction [i]
from [instrs] such that [i] appears after
all other instructions from [instrs] according to the order of instructions
in this basic block. Raises if [instrs] is empty. *)
val find_last_instruction : t -> Instruction.Id.t list -> Instruction.t
(** [find_last_instruction_id_and_pos group block] returns scalar instruction [i] from
[group] and its position [pos] such that [i] appears after all other instructions
from [group] according to the order of instructions in this basic [block]. *)
val find_last_instruction_id_and_pos :
t -> Instruction.t list -> Instruction.Id.t * int

val get_live_regs_before_terminator : t -> State.live_regs

Expand Down Expand Up @@ -417,28 +417,29 @@ end = struct
let get_live_regs_before_terminator t =
State.liveness t.state t.block.terminator.id

let find_last_instruction t instructions =
let instruction_set = Instruction.Id.Set.of_list instructions in
let terminator = terminator t in
if Instruction.Id.Set.mem (Instruction.id terminator) instruction_set
then terminator
else
let body = t.block.body in
let rec find_last cell_option =
match cell_option with
| None ->
Misc.fatal_errorf "Vectorizer.find_last_instruction in block %a"
Label.print t.block.start ()
| Some cell ->
let current_instruction = Instruction.basic (DLL.value cell) in
let current_instruction_id = Instruction.id current_instruction in
if Instruction.Id.Set.exists
(Instruction.Id.equal current_instruction_id)
instruction_set
then current_instruction
else find_last (DLL.prev cell)
in
find_last (DLL.last_cell body)
let find_last_instruction_id_and_pos t instructions =
let get instr =
let id = Instruction.id instr in
let pos = pos t id in
id, pos
in
let rec loop instructions last_id last_pos =
match instructions with
| [] -> last_id, last_pos
| hd :: tl ->
let hd_id, hd_pos = get hd in
if Int.compare hd_pos last_pos > 0
then loop tl hd_id hd_pos
else loop tl last_id last_pos
in
let loop_non_empty instructions =
match instructions with
| [] -> assert false
| hd :: tl ->
let last_id, last_pos = get hd in
loop tl last_id last_pos
in
loop_non_empty instructions
end

(* CR-someday gyorsh: Dependencies computed below can be used for other
Expand Down Expand Up @@ -2328,12 +2329,16 @@ end = struct

type t =
{ groups : Group.t Instruction.Id.Map.t;
(* [all_instructions] is all the scalar instructions in the computations.
It is an optimization to cache this value here. It is used for ruling
out computations that are invalid or not implementable, and to estimate
cost/benefit of vectorized computations. *)
all_scalar_instructions : Instruction.Id.Set.t;
new_positions : int Instruction.Id.Map.t
(** [all_scalar_instructions] is all the scalar instructions in the
computations. It is an optimization to cache this value here. It is used
for ruling out computations that are invalid or not implementable, and to
estimate cost/benefit of vectorized computations. *)
new_positions : int Instruction.Id.Map.t;
(** [new_positions] is used for validation. *)
last_pos : int option
(** [last_pos] the position in the block body of the last scalar instruction, used
for heuristics. [None] for empty computations. *)
}

let num_groups t = Instruction.Id.Map.cardinal t.groups
Expand Down Expand Up @@ -2575,19 +2580,31 @@ end = struct
&& respects_register_order_constraints t deps
&& not (is_dependency_of_outside_body t block deps)

(** The key is the last instruction id, for now. This is the place
where the vectorized intructions will be inserted. *)
let get_key block instruction_ids =
let last_instruction = Block.find_last_instruction block instruction_ids in
Instruction.id last_instruction
(** The key is the last instruction id, for now. This is the place in the body of the
block where the vectorized instructions will be inserted. *)
let get_key group block =
let id, _pos =
Block.find_last_instruction_id_and_pos block
(Group.scalar_instructions group)
in
id

let get_last_pos group block =
let _id, pos =
Block.find_last_instruction_id_and_pos block
(Group.scalar_instructions group)
in
pos

(** Returns the dependencies of arguments at position [arg_i]
of each instruction in [instruction_ids]. Returns None if
one of the instruction's dependencies is None for [arg_i]. *)
let get_deps deps ~arg_i instruction_ids =
let get_deps deps ~arg_i group =
Misc.Stdlib.List.map_option
(Dependencies.get_direct_dependency_of_arg deps ~arg_i)
instruction_ids
(fun instruction ->
let id = Instruction.id instruction in
Dependencies.get_direct_dependency_of_arg deps ~arg_i id)
(Group.scalar_instructions group)

let all_instructions map =
Instruction.Id.Map.fold
Expand Down Expand Up @@ -2617,7 +2634,8 @@ end = struct
let empty =
{ groups = Instruction.Id.Map.empty;
all_scalar_instructions = Instruction.Id.Set.empty;
new_positions = Instruction.Id.Map.empty
new_positions = Instruction.Id.Map.empty;
last_pos = None
}

(* CR gyorsh: if same instruction belongs to two groups, is it handled
Expand All @@ -2632,10 +2650,7 @@ end = struct
match group with
| None -> None
| Some (group : Group.t) -> (
let instruction_ids =
Group.scalar_instructions group |> List.map Instruction.id
in
let key = get_key block instruction_ids in
let key = get_key group block in
(* Is there another group with the same key already in the tree? If the
key instruction of the group is already in another group, and the other
group is different from this group, we won't vectorize this for
Expand All @@ -2657,7 +2672,7 @@ end = struct
(* CR-someday gyorsh: refer directly to [Reg.t] instead of
positional [arg_i]. Currently, the code assumes that address
args are always at the end. *)
match get_deps deps ~arg_i instruction_ids with
match get_deps deps ~arg_i group with
| None ->
(* At least one of the arguments has a dependency outside the
block. Currently, not supported. *)
Expand Down Expand Up @@ -2689,14 +2704,21 @@ end = struct
let t =
{ groups = map;
all_scalar_instructions = all_instructions map;
new_positions = new_positions map block
new_positions = new_positions map block;
last_pos = Some (get_last_pos root block)
}
in
State.dump_debug (Block.state block)
"Computation.from_seed build finished\n%a\n" (dump ~block) t;
assert (seed_address_does_not_depend_on_tree t block deps seed);
if is_valid t block deps then Some t else None

let max_pos o1 o2 =
match o1, o2 with
| Some p1, Some p2 -> Some (Int.max p1 p2)
| None, None -> None
| (Some _ as res), None | None, (Some _ as res) -> res

let join t1 t2 =
{ groups =
Instruction.Id.Map.union
Expand All @@ -2722,7 +2744,8 @@ end = struct
pos2=%d"
Instruction.Id.print key pos1 pos2;
Some pos1)
t1.new_positions t2.new_positions
t1.new_positions t2.new_positions;
last_pos = max_pos t1.last_pos t2.last_pos
}

(** address registers and vectorizable registers of [t] and [t'] are compatible, i.e.,
Expand Down Expand Up @@ -2784,7 +2807,16 @@ end = struct
| trees ->
(* sort by cost, ascending *)
let compare_cost t1 t2 = Int.compare (cost t1) (cost t2) in
let trees = List.sort compare_cost trees in
let compare_cost_and_last_pos t1 t2 =
let c = compare_cost t1 t2 in
if not (c = 0)
then c
else
(* heuristic to prioritize groups that appear later, it reduces the
chance they are a dependency of the rest of the body. *)
Int.neg (Option.compare Int.compare t1.last_pos t2.last_pos)
in
let trees = List.sort compare_cost_and_last_pos trees in
let rec loop trees acc =
match trees with
| [] -> acc
Expand Down
Loading