Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#55 from ajfadam/main
Browse files (browse the repository at this point in the history)
remove numpy dependency
  • Loading branch information
tridao committed Oct 6, 2022
2 parents 88dc204 + 4e38df0 commit 8dd52b0
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions flash_attn/bert_padding.py
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

-import numpy as np
-
import torch
import torch.nn.functional as F

Expand All @@ -15,7 +13,7 @@ def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-second_dim = np.prod(other_shape)
+second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
Expand Down Expand Up @@ -71,7 +69,7 @@ def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
-second_dim = np.prod(other_shape)
+second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
output = input[indices]
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
Expand Down

0 comments on commit 8dd52b0

Please sign in to comment.