Skip to content

Commit

Permalink
Merge branch 'sparse_ml' of https://github.com/weiya711/sam into sparse_ml
Browse files Browse the repository at this point in the history
  • Loading branch information
lrubens committed Oct 21, 2024
2 parents c6feb2a + 40f1e81 commit fdc1daf
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 14 deletions.
1 change: 1 addition & 0 deletions sam/sim/test/gold.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,7 @@ def check_gold_tensor4_softmax_mask(frosttname, debug_sim, cast, out_crds, out_s
print(out_vals)
B_ref = B_ref.masked_fill(B_ref == 0, -1e9)
gold_ref = torch.nn.functional.softmax(B_ref, dim=3)
gold_ref[gold_ref==1/B_shape[3]] = 0.0
gold_ref = gold_ref.numpy()

print(gold_ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def test_tensor4_multihead_attention_ijklm(samBench, frosttname, cast, check_gol
sub_arr =[]
test_arr = []

soft_out = []

while not done and time_cnt < TIMEOUT:
if len(in_ref_V) > 0:
fiberlookup_Vi_35.set_in_ref(in_ref_V.pop(0))
Expand Down Expand Up @@ -225,6 +227,8 @@ def test_tensor4_multihead_attention_ijklm(samBench, frosttname, cast, check_gol
mul_46.set_in1(arrayvals_Q_47.out_val())
mul_46.set_in2(arrayvals_K_48.out_val())
reduce_45.set_in_val(mul_46.out_val())
soft_out.append(reduce_45.out_val())
print("Val: ", remove_emptystr(soft_out))
# QK_T / sqrt(d_k)
scalar_mul.set_in1(reduce_45.out_val())
# print("scalar", scalar_mul.out_val())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ def test_tensor4_multihead_attention_ijklm(samBench, frosttname, cast, check_gol
crddrop_ij = CrdDrop(debug=debug_sim, statistics=report_stats, back_en=backpressure, depth=int(depth))
scalar_mul = ScalarMult(in2 = 1.0 / sqrt(Q_shape[3]), debug=debug_sim, statistics=report_stats, back_en=backpressure, depth=int(depth))
comp_drop_1 = ValDropper(debug=debug_sim, drop_refs=True, statistics=report_stats)
tril = Tril(debug=debug_sim, statistics=report_stats, back_en=backpressure, depth=int(depth))
# tril = Tril(debug=debug_sim, statistics=report_stats, back_en=backpressure, depth=int(depth))
tril = LowerTriangular(dimension=2, debug=debug_sim, statistics=report_stats, back_en=backpressure, depth=int(depth))
crd = CrdHold(debug=debug_sim)
drop_1 = CrdDrop(debug=debug_sim, statistics=report_stats)
drop_2 = CrdDrop(debug=debug_sim, statistics=report_stats)
# drop = Dropout(debug=debug_sim, drop_prob=0.5, statistics=report_stats, back_en=backpressure, depth=int(depth))
in_ref_V = [0, 'D']
in_ref_Q = [0, 'D']
Expand All @@ -174,6 +178,8 @@ def test_tensor4_multihead_attention_ijklm(samBench, frosttname, cast, check_gol
time_cnt = 0

soft_out = []
mul = []
mul0 = []

while not done and time_cnt < TIMEOUT:
if len(in_ref_V) > 0:
Expand Down Expand Up @@ -223,19 +229,40 @@ def test_tensor4_multihead_attention_ijklm(samBench, frosttname, cast, check_gol
mul_46.set_in2(arrayvals_K_48.out_val())
reduce_45.set_in_val(mul_46.out_val())

tril.set_inner_crd(fiberlookup_Kl_416.out_crd())
tril.set_outer_crd(fiberlookup_Qk_420.out_crd())
soft_out.append(reduce_45.out_val())
print("Val: ", remove_emptystr(soft_out))
# mul.append(arrayvals_K_48.out_val())
mul.append(intersecti2_424.out_crd())
print("mul", remove_emptystr(mul))
mul0.append(arrayvals_Q_47.out_val())
print("mul0", remove_emptystr(mul0))
# tril.set_inner_crd(fiberlookup_Kl_416.out_crd())
# tril.set_outer_crd(fiberlookup_Qk_420.out_crd())
# tril.set_inner_ref(reduce_45.out_val())
# tril.set_inner_ref(reduce_45.out_val())
tril.set_inner_ref(fiberlookup_Kl_416.out_ref())
tril.set_crd1(intersectj2_421.out_crd())
tril.set_crd0(intersecti2_424.out_crd())
# tril.set_inner_ref(fiberlookup_Kl_416.out_ref())
# tril.set_crd1(intersectj2_421.out_crd())
# tril.set_crd0(intersecti2_424.out_crd())

crd.set_outer_crd(fiberlookup_Qk_420.out_crd())
crd.set_inner_crd(fiberlookup_Kl_416.out_crd())

tril.set_inner_crd(crd.out_crd_inner())
tril.set_outer_crd(crd.out_crd_outer())
tril.set_inner_ref(reduce_45.out_val())

drop_1.set_inner_crd(tril.out_crd(1))
drop_1.set_outer_crd(intersectj2_421.out_crd())

drop_2.set_inner_crd(drop_1.out_crd_outer())
drop_2.set_outer_crd(intersecti2_424.out_crd())

# QK_T / sqrt(d_k)
scalar_mul.set_in1(tril.out_ref())
# scalar_mul.set_in1(reduce_45.out_val())

# TODO: Replace with softmax superblock
repsiggen_l1_414.set_istream(tril.out_crd_inner())
repsiggen_l1_414.set_istream(tril.out_crd(0))
maxreduce_434.set_in_val(scalar_mul.out_val())
repeat_QKl_437.set_in_repsig(repsiggen_l1_414.out_repsig())
repeat_QKl_437.set_in_ref(maxreduce_434.out_val())
Expand All @@ -252,11 +279,9 @@ def test_tensor4_multihead_attention_ijklm(samBench, frosttname, cast, check_gol
comp_drop_1.set_val(div_432.out_val())
# comp_drop_1.set_crd(fiberlookup_Kl_416.out_crd())
# comp_drop_1.set_ref(fiberlookup_Kl_416.out_ref())
comp_drop_1.set_crd(tril.out_crd_inner())
comp_drop_1.set_crd(tril.out_crd(0))
comp_drop_1.set_ref(tril.out_ref())

soft_out.append(tril.out_crd_inner())
print("Val: ", remove_emptystr(soft_out))

intersectl_23.set_in1(fiberlookup_Vl_25.out_ref(), fiberlookup_Vl_25.out_crd())
intersectl_23.set_in2(comp_drop_1.out_ref(), comp_drop_1.out_crd())
Expand All @@ -270,11 +295,13 @@ def test_tensor4_multihead_attention_ijklm(samBench, frosttname, cast, check_gol
mul_15.set_in2(arrayvals_V_17.out_val())

# crddrop_kl.set_outer_crd(fiberlookup_Qk_420.out_crd())
crddrop_kl.set_outer_crd(tril.out_crd_outer())
crddrop_kl.set_outer_crd(tril.out_crd(1))
crddrop_kl.set_inner_crd(intersectl_23.out_crd())
crddrop_jk.set_outer_crd(intersectj2_421.out_crd())
# crddrop_jk.set_outer_crd(intersectj2_421.out_crd())
crddrop_jk.set_outer_crd(drop_1.out_crd_outer())
crddrop_jk.set_inner_crd(crddrop_kl.out_crd_outer())
crddrop_ij.set_outer_crd(intersecti2_424.out_crd())
# crddrop_ij.set_outer_crd(intersecti2_424.out_crd())
crddrop_ij.set_outer_crd(drop_2.out_crd_outer())
crddrop_ij.set_inner_crd(crddrop_jk.out_crd_outer())

spaccumulator1_5.set_in_crd0(fiberlookup_Vm_22.out_crd())
Expand Down Expand Up @@ -315,8 +342,12 @@ def test_tensor4_multihead_attention_ijklm(samBench, frosttname, cast, check_gol
arrayvals_K_48.update()
mul_46.update()
reduce_45.update()
crd.update()
tril.update()
drop_1.update()
drop_2.update()
scalar_mul.update()
repsiggen_l1_414.update()
maxreduce_434.update()
repeat_QKl_437.update()
add_433.update()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
@pytest.mark.frostt
def test_tensor4_multihead_attention_ijklm(samBench, frosttname, cast, check_gold, debug_sim, backpressure, depth, report_stats, fill=0):
test_name = "tensor4_fused_mul_T1"
test_name = "tensor4_mha1"
Q_dirname = os.path.join(formatted_dir, frosttname, test_name)
Q_shape_filename = os.path.join(Q_dirname, "tensor_Q_mode_shape")
Q_shape = read_inputs(Q_shape_filename)
Expand Down

0 comments on commit fdc1daf

Please sign in to comment.