Commit 6529bea
update
MzeroMiko committed Feb 25, 2024
1 parent bf16b78 commit 6529bea
Showing 1 changed file with 25 additions and 20 deletions.
test_selective_scan.py: 45 changes (25 additions & 20 deletions)
@@ -50,14 +50,17 @@ def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix):
         ys = torch.einsum("lbgn,lbgdn->lbgd", Cs, hs)
         return ys, hs
 
+
+    dtype = torch.float32
+    # dtype = torch.float16
     inp_dtype = us.dtype
     has_D = Ds is not None
     if chunksize < 1:
         chunksize = Bs.shape[-1]
 
-    dts = dts.float()
+    dts = dts.to(dtype)
     if delta_bias is not None:
-        dts = dts + delta_bias.view(1, -1, 1).float()
+        dts = dts + delta_bias.view(1, -1, 1).to(dtype)
     if delta_softplus:
         dts = torch.nn.functional.softplus(dts)
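The hunk above replaces hardcoded .float() casts with a configurable compute dtype, while inp_dtype records the caller's dtype so results can be cast back on return. A minimal sketch of that round-trip pattern (standalone Python; toy_op is a hypothetical stand-in for the real scan):

import torch

def toy_op(x, compute_dtype=torch.float32):
    inp_dtype = x.dtype                            # remember the caller's dtype
    y = torch.cumsum(x.to(compute_dtype), dim=-1)  # numerically sensitive work in compute_dtype
    return y.to(inp_dtype)                         # hand the result back in the input dtype

x = torch.randn(4, 8, dtype=torch.float16)
assert toy_op(x).dtype == torch.float16            # fp16 in, fp16 out, fp32 inside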
@@ -66,16 +69,16 @@ def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix):
     if len(Cs.shape) == 3:
         Cs = Cs.unsqueeze(1)
     B, G, N, L = Bs.shape
-    us = us.view(B, G, -1, L).permute(3, 0, 1, 2).float()
-    dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).float()
-    As = As.view(G, -1, N).float()
-    Bs = Bs.permute(3, 0, 1, 2).float()
-    Cs = Cs.permute(3, 0, 1, 2).float()
-    Ds = Ds.view(G, -1).float() if has_D else None
+    us = us.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype)
+    dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype)
+    As = As.view(G, -1, N).to(dtype)
+    Bs = Bs.permute(3, 0, 1, 2).to(dtype)
+    Cs = Cs.permute(3, 0, 1, 2).to(dtype)
+    Ds = Ds.view(G, -1).to(dtype) if has_D else None
     D = As.shape[1]
 
     oys = []
-    hprefix = us.new_zeros((B, G, D, N), dtype=torch.float)
+    hprefix = us.new_zeros((B, G, D, N), dtype=dtype)
     for i in range(0, L, chunksize):
         ys, hs = selective_scan_chunk(
             us[i:i + chunksize], dts[i:i + chunksize],
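The loop above threads a running state hprefix through successive selective_scan_chunk calls, so a length-L scan is evaluated chunksize steps at a time. The same carry pattern on a plain cumulative sum (self-contained sketch, not the real kernel):

import torch

def cumsum_chunk(x, carry):
    ys = carry + torch.cumsum(x, dim=0)  # scan one chunk, seeded by the previous carry
    return ys, ys[-1]                    # chunk outputs, plus the new carry

L, chunksize = 10, 4
x = torch.randn(L, 3)
carry = torch.zeros(3)
outs = []
for i in range(0, L, chunksize):
    ys, carry = cumsum_chunk(x[i:i + chunksize], carry)
    outs.append(ys)
assert torch.allclose(torch.cat(outs), torch.cumsum(x, dim=0), atol=1e-6)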
@@ -90,7 +93,7 @@ def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix):
     oys = oys.permute(1, 2, 3, 0).view(B, -1, L)
 
     # return oys, hprefix.view(B, G * D, N)
-    return oys.to(inp_dtype) if not return_last_state else (oys.to(inp_dtype), hprefix.view(B, G * D, N))
+    return oys.to(inp_dtype) if not return_last_state else (oys.to(inp_dtype), hprefix.view(B, G * D, N).float())
 
 
 class SelectiveScanEasy(torch.autograd.Function):
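After this hunk the two returned tensors intentionally differ in dtype: oys comes back in the caller's dtype, while the final state is pinned to float32 by the newly added .float(). A hedged usage sketch (shapes inferred from the view/permute calls above, and assuming the enclosing function is exported as selective_scan_easy):

import torch
from test_selective_scan import selective_scan_easy  # assumed name of the enclosing function

B, G, D, N, L = 2, 2, 4, 8, 64
us = torch.randn(B, G * D, L, dtype=torch.float16)
dts = torch.randn(B, G * D, L, dtype=torch.float16)
As = -torch.rand(G * D, N)  # negative values keep exp(dt * A) from blowing up
Bs = torch.randn(B, G, N, L)
Cs = torch.randn(B, G, N, L)
Ds = torch.ones(G * D)

oys, last = selective_scan_easy(us, dts, As, Bs, Cs, Ds,
                                delta_softplus=True, return_last_state=True)
print(oys.dtype)   # matches us.dtype (torch.float16 here)
print(last.dtype)  # torch.float32, forced by the new .float()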
@@ -109,10 +112,11 @@ def save_for_backward(ctx, *args):
     @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64):
         has_D = Ds is not None
+        dtype = torch.float32
 
-        dts = dts.float()
+        dts = dts.to(dtype)
         if delta_bias is not None:
-            dts = dts + delta_bias.view(1, -1, 1).float()
+            dts = dts + delta_bias.view(1, -1, 1).to(dtype)
         if delta_softplus:
             dts = torch.nn.functional.softplus(dts)
@@ -123,12 +127,12 @@ def forward(ctx, us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False,
         if C_squeeze:
             Cs = Cs.unsqueeze(1)
         B, G, N, L = Bs.shape
-        us = us.view(B, G, -1, L).permute(3, 0, 1, 2).float()
-        dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).float()
-        As = As.view(G, -1, N).float()
-        Bs = Bs.permute(3, 0, 1, 2).float()
-        Cs = Cs.permute(3, 0, 1, 2).float()
-        Ds = Ds.view(G, -1).float() if has_D else None
+        us = us.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype)
+        dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).to(dtype)
+        As = As.view(G, -1, N).to(dtype)
+        Bs = Bs.permute(3, 0, 1, 2).to(dtype)
+        Cs = Cs.permute(3, 0, 1, 2).to(dtype)
+        Ds = Ds.view(G, -1).to(dtype) if has_D else None
         D = As.shape[1]
 
         ctx.shape = (B, G, D, N, L)
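ctx.shape = (B, G, D, N, L) stashes plain Python values on the context for the backward pass; only tensors need to go through save_for_backward. A tiny sketch of that convention on a hypothetical op:

import torch

class ScaleSquareFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor       # non-tensor metadata rides directly on ctx
        ctx.save_for_backward(x)  # tensors go through save_for_backward
        return factor * x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * ctx.factor * x * grad_out, None  # None for the non-tensor input

x = torch.randn(5, requires_grad=True)
ScaleSquareFn.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, 4.0 * x.detach())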
@@ -191,6 +195,7 @@ def rev_comsum(x):
             return (x - cum_sum + cum_sum[-1:None])
 
         if DEBUG:
+            dtype = torch.float32
             us = us.requires_grad_()
             dts = dts.requires_grad_()
             As = As.requires_grad_()
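The DEBUG branch re-enables requires_grad_() on the cast tensors so the handwritten gradients can be compared against autograd further down. For a standalone check of a chunked scan's backward, torch.autograd.gradcheck in float64 plays the same role with tight tolerances (a sketch on a toy chunked cumsum, not the real scan):

import torch

def chunked_cumsum(x, chunksize=2):
    outs, carry = [], torch.zeros_like(x[:1])
    for i in range(0, x.shape[0], chunksize):
        ys = carry + torch.cumsum(x[i:i + chunksize], dim=0)
        outs.append(ys)
        carry = ys[-1:]
    return torch.cat(outs, dim=0)

# gradcheck compares analytic and numeric Jacobians; float64 keeps it tight.
x = torch.randn(6, 3, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(chunked_cumsum, (x,))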
@@ -307,9 +312,9 @@ def rev_comsum(x):
                 print("3", (torch.autograd.grad(_oys, tmp_fwd_dtBus[chunks.index(i)], doys, create_graph=True, allow_unused=True)[0] - ddtBus).abs().sum())
 
         if DEBUG:
-            tmp_a = torch.randn((L, B, G, D, N)).float().cuda().requires_grad_()
+            tmp_a = torch.randn((L, B, G, D, N)).to(dtype).cuda().requires_grad_()
             tmp_b = torch.cumsum(tmp_a, dim=0)
-            tmp_c = torch.randn((L, B, G, D, N)).float().cuda()
+            tmp_c = torch.randn((L, B, G, D, N)).to(dtype).cuda()
             print("ex.0", (torch.autograd.grad(tmp_b, tmp_a, tmp_c, create_graph=True, allow_unused=True)[0] - rev_comsum(tmp_c)).abs().sum())
 
         drAts_dtBus_div_rAts = d_dtBus_div_rAts * (-dtBus_div_rAts / rAts)
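The ex.0 print checks the identity behind rev_comsum: backpropagating through cumsum applies a reversed cumulative sum to the incoming gradient, and x - cumsum(x) + cumsum(x)[-1:] computes exactly that without any flip. The same check, standalone and on CPU:

import torch

def rev_comsum(x):
    cum_sum = x.cumsum(dim=0)
    return x - cum_sum + cum_sum[-1:None]  # == x.flip(0).cumsum(0).flip(0)

a = torch.randn(16, 3, requires_grad=True)
b = torch.cumsum(a, dim=0)
g = torch.randn(16, 3)
(grad,) = torch.autograd.grad(b, a, g)
assert torch.allclose(grad, rev_comsum(g), atol=1e-6)
assert torch.allclose(rev_comsum(g), g.flip(0).cumsum(0).flip(0), atol=1e-6)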
