
Fixed blocked tensor generator in onyx
lrubens committed Oct 21, 2024
1 parent 12a7fc2 commit c6feb2a
Showing 1 changed file with 36 additions and 9 deletions.
sam/onyx/generate_matrices.py: 45 changes (36 additions & 9 deletions)
@@ -85,8 +85,9 @@ def _create_matrix(self, value_cap=int(math.pow(2, 8)) - 1):
# print(self.array[...,:self.shape[-2]:self.block_size,:self.shape[-1]:self.block_size])

def _create_fiber_tree(self):
self.fiber_tree = FiberTree(tensor=self.array if self.block_naive else self.array[...,
self.shape[-2]:self.block_size,:self.shape[-1]:self.block_size])
# self.fiber_tree = FiberTree(tensor=self.array if self.block_naive else self.array[...,
# self.shape[-2]:self.block_size,:self.shape[-1]:self.block_size])
self.fiber_tree = FiberTree(tensor=self.array)
self.tmp_fiber_tree = FiberTree(tensor=self.array[..., :self.shape[-2]:self.block_size,
:self.shape[-1]:self.block_size])

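As a side note on the strided slice used in the commented-out call above and in tmp_fiber_tree: taking every block_size-th entry along the last two dimensions keeps one representative element per block, i.e. the coarse, block-level view of the tensor. A minimal numpy sketch of that behavior (the 4x4 array and 2x2 block size below are hypothetical, purely for illustration):

    import numpy as np

    block_size = 2
    dense = np.arange(16).reshape(4, 4)  # illustrative 4x4 tensor

    # Same slicing pattern as above: keep one entry per block_size x block_size block.
    block_level = dense[..., :dense.shape[-2]:block_size, :dense.shape[-1]:block_size]

    print(block_level.shape)  # (2, 2): one representative per 2x2 block
    print(block_level)        # [[0 2] [8 10]], printed on two lines
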
@@ -170,12 +171,18 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
# In CSF format, need to iteratively create seg/coord arrays
tmp_lvl_list = []
small_tmp_lvl_list = []
# if self.block_size == 1:
tmp_lvl_list.append(self.fiber_tree.get_root())
# else:
# tmp_lvl_list.append(self.tmp_fiber_tree.get_root())
small_tmp_lvl_list.append(self.tmp_fiber_tree.get_root())
# print(small_tmp_lvl_list)
# print(tmp_lvl_list)

seg_arr, coord_arr = None, None
if self.block_size > 1:
seg_arr, coord_array = self._dump_csf(small_tmp_lvl_list)
seg_arr, coord_arr = self._dump_csf(small_tmp_lvl_list)
# print(seg_arr, coord_arr)
else:
seg_arr, coord_arr = self._dump_csf(tmp_lvl_list)
if glb_override:
@@ -184,6 +191,7 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
hex=print_hex)
else:
print(self.mode_ordering)
# print(seg_arr, coord_arr)
self.write_array(seg_arr, name=f"tensor_{self.name}_mode_{self.mode_ordering[0]}_seg", dump_dir=use_dir,
hex=print_hex)
self.write_array(coord_arr, name=f"tensor_{self.name}_mode_{self.mode_ordering[0]}_crd",
@@ -192,10 +200,12 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
at_vals = False
# TODO: Might need a try catch
i = 1

tmp_lst = tmp_lvl_list if self.block_size == 1 else small_tmp_lvl_list
while at_vals is False:
# Make the next level of fibers - basically BFS but segmented across depth of tree
next_tmp_lvl_list = []
for fib in tmp_lvl_list:
for fib in tmp_lst:
crd_payloads_tmp = fib.get_coord_payloads()
if type(crd_payloads_tmp[0][1]) is not FiberTreeFiber:
at_vals = True
@@ -204,18 +214,34 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
else:
for crd, pld in crd_payloads_tmp:
next_tmp_lvl_list.append(pld)
tmp_lvl_list = next_tmp_lvl_list
tmp_lst = next_tmp_lvl_list
if at_vals:
# If at vals, we don't need to dump csf, we have the level
if glb_override:
lines = [len(tmp_lvl_list), *tmp_lvl_list]
lines = [len(tmp_lst), *tmp_lst]
# self.write_array(tmp_lvl_list, name=f"tensor_{self.name}_mode_vals" dump_dir=use_dir)
self.write_array(lines, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir, hex=print_hex)
else:
self.write_array(tmp_lvl_list, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir,
reached_full_vals = False
if self.block_size > 1:
tmp_lst = tmp_lvl_list
# Retrieve values from full tensor
while not reached_full_vals:
next_tmp_lvl_list = []
for fib in tmp_lst:
crd_payloads_tmp = fib.get_coord_payloads()
if type(crd_payloads_tmp[0][1]) is not FiberTreeFiber:
reached_full_vals = True
for crd, pld in crd_payloads_tmp:
next_tmp_lvl_list.append(pld)
else:
for crd, pld in crd_payloads_tmp:
next_tmp_lvl_list.append(pld)
tmp_lst = next_tmp_lvl_list
self.write_array(tmp_lst, name=f"tensor_{self.name}_mode_vals", dump_dir=use_dir,
hex=print_hex)
else:
seg_arr, coord_arr = self._dump_csf(tmp_lvl_list)
seg_arr, coord_arr = self._dump_csf(tmp_lst)
if glb_override:
lines = [len(seg_arr), *seg_arr, len(coord_arr), *coord_arr]
self.write_array(lines, name=f"tensor_{self.name}_mode_{self.mode_ordering[i]}",
@@ -267,7 +293,8 @@ def dump_outputs(self, format=None, tpose=False, dump_shape=True, glb_override=F
hex=print_hex)

if dump_shape:
self.write_array(self.array.shape, name=f"tensor_{self.name}_mode_shape", dump_dir=use_dir, hex=print_hex)
final_shape = [x // self.block_size for x in self.array.shape]
self.write_array(final_shape, name=f"tensor_{self.name}_mode_shape", dump_dir=use_dir, hex=print_hex)

# Transpose it back
if tpose is True:
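For readers unfamiliar with the seg/coord arrays written throughout dump_outputs, the layout follows the usual compressed-sparse-fiber idea: each level stores a segment array of running offsets plus a flat coordinate array, much like CSR row pointers and column indices. A small self-contained sketch of one such level (independent of the FiberTree and _dump_csf machinery in this file; the fiber list below is made up for illustration):

    # Nonzero coordinates held by each fiber at one level (illustrative data).
    fibers = [[0, 2], [1], [0, 1, 3]]

    seg = [0]      # running offsets: fiber i spans coord[seg[i]:seg[i + 1]]
    coord = []     # flattened coordinates across all fibers
    for fib in fibers:
        coord.extend(fib)
        seg.append(len(coord))

    print(seg)    # [0, 2, 3, 6]
    print(coord)  # [0, 2, 1, 0, 1, 3]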
