Revert "Call _sdp_attention in nn.functional.mha (pytorch#89470)"
This reverts commit 4d7ec30.

Reverted pytorch#89470 on behalf of https://github.com/jeanschmidt due to breaking internal builds
pytorchmergebot committed Nov 30, 2022
1 parent 618a585 commit f1415b8
Showing 6 changed files with 17 additions and 53 deletions.
9 changes: 5 additions & 4 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -9,6 +9,7 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>


#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
@@ -740,10 +741,10 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
}
auto attn_mask = attn_mask_;
// Naive, composite implementation defined here.
const auto embed_size = query_.size(-1);

// Scale q,k before matmul for stability see https://tinyurl.com/sudb9s96 for math
const auto embed_size = SymFloat(query_.sym_size(-1));
const auto scaling_factor = embed_size.sqrt().sqrt();
const double scaling_factor = ::sqrt(::sqrt(static_cast<double>(embed_size)));
const auto query = query_ / scaling_factor;
if (is_causal) {
TORCH_CHECK(!attn_mask.has_value(),
@@ -752,8 +753,8 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
"_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");

// Replace attn_mask with causal mask; lower triangular elements take part in attention.
const auto L = query.sym_size(-2), S = key.sym_size(-2);
attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
const auto L = query.size(-2), S = key.size(-2);
attn_mask = at::ones({L, S}, query.options().dtype(at::kBool)).tril();
}
if (attn_mask.has_value()) {
TORCH_CHECK(!query.is_nested() && !key.is_nested(),
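For reference, the hunks above restore the double-precision scaling (embed_size raised to the 1/4 power via nested ::sqrt) and the non-symbolic causal mask in the composite math path. A minimal Python sketch of that path, assuming the usual split scaling of q and k and a boolean lower-triangular mask; the names and masking details are illustrative, not the ATen implementation:

import math
import torch

def sdpa_math_sketch(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # Scale q and k by embed_size ** -0.25 each, so their product carries the usual
    # 1 / sqrt(embed_size) factor; splitting the scale helps numerical stability
    # (the "Scale q,k before matmul for stability" comment in the hunk above).
    embed_size = query.size(-1)
    scaling_factor = math.sqrt(math.sqrt(embed_size))
    q = query / scaling_factor
    k = key / scaling_factor
    if is_causal:
        assert attn_mask is None, "is_causal and an explicit attn_mask are mutually exclusive"
        L, S = q.size(-2), k.size(-2)
        # Lower-triangular True entries take part in attention, as in the tril() call above.
        attn_mask = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()
    attn = q @ k.transpose(-2, -1)
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn = attn.masked_fill(~attn_mask, float("-inf"))
        else:
            attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = torch.dropout(attn, dropout_p, train=True)
    return attn @ value, attn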
26 changes: 2 additions & 24 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.h
@@ -40,9 +40,7 @@ inline bool check_tensor_dtype(
allowed_dtypes.end()))) {
TORCH_CHECK(
!debug,
"Expected query, key and value to all be of dtype: {",
c10::Join(", ", allowed_dtypes), "}. Got ",
"Query dtype: ",
"Expected query, key and value to be of dtype float16 or bfloat16 but got Query dtype: ",
params.query.dtype(),
", Key dtype: ",
params.key.dtype(),
@@ -164,25 +162,6 @@ inline bool check_head_dim_size(sdp_params params, bool debug) {
return true;
}

inline bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
const int64_t query_size_last = params.query.size(-1);
if (!(query_size_last == params.key.size(-1) &&
query_size_last == params.value.size(-1) && query_size_last >= 8)) {
TORCH_CHECK(
!debug,
"Mem efficient attention requires last dimension of inputs to be >= 8.",
"Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
params.key.size(-1),
", Value.size(-1): ",
params.value.size(-1),
" instead.");
return false;
}
return true;
}

inline bool check_runtime_disabled_flash(sdp_params params, bool debug) {
// We check the global context to see if user has explicitly turned of flash
// sdp kernels
@@ -280,14 +259,13 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
at::kHalf, at::kFloat, at::kBFloat16};

// Define gate functions that determine if a flash kernel can be ran
constexpr std::array<bool(*)(sdp_params, bool), 9> constraints{{
constexpr std::array<bool(*)(sdp_params, bool), 8> constraints{{
check_gpu_sm50_or_greater,
check_runtime_disabled_mem_efficient,
check_requires_grad_and_nested,
check_for_attn_weights,
check_tensor_shapes,
check_for_attn_mask,
check_head_dim_size_mem_efficient,
check_for_seq_len_1_nested_tensor,
check_for_non_zero_dropout}};
for (auto& constraint : constraints) {
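The change above drops check_head_dim_size_mem_efficient and shrinks the constraint array from nine gates to eight. Each kernel selector walks an ordered list of predicate "gate" functions over the sdp_params and rejects the kernel as soon as one returns false. A rough Python sketch of that dispatch pattern, with hypothetical names rather than the real sdp_utils API:

from dataclasses import dataclass
import torch

@dataclass
class SDPParams:
    # Stand-in for the C++ sdp_params struct; fields are illustrative.
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor

def check_tensor_shapes(params: SDPParams, debug: bool) -> bool:
    # Example gate: require 4-D (batch, heads, seq_len, head_dim) inputs.
    ok = params.query.dim() == params.key.dim() == params.value.dim() == 4
    if not ok and debug:
        print("Expected 4-D query, key and value tensors")
    return ok

def use_kernel_sketch(params: SDPParams, debug: bool = False) -> bool:
    # The C++ selector lists eight such gates after this commit; any failure
    # falls back to another attention implementation.
    constraints = (check_tensor_shapes,)
    return all(constraint(params, debug) for constraint in constraints)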
10 changes: 0 additions & 10 deletions c10/core/SymFloat.cpp
@@ -1,7 +1,6 @@
#include <c10/core/SymFloat.h>
#include <c10/core/SymNodeImpl.h>
#include <array>
#include <cmath>
#include <utility>

namespace c10 {
@@ -71,15 +70,6 @@ std::ostream& operator<<(std::ostream& os, const SymFloat& s) {
return os;
}

SymFloat SymFloat::sqrt() const {
if (!is_symbolic()) {
return SymFloat(std::sqrt(data_));
}
auto other = SymFloat(-0.5);
auto res = normalize_symfloats(*this, other);
return SymFloat(res[0]->pow(res[1]));
}

double SymFloat::guard_float(const char* file, int64_t line) const {
if (!is_symbolic()) {
return data_;
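The removed SymFloat::sqrt followed the usual SymFloat dispatch pattern: compute eagerly on a concrete double, otherwise normalize the operands to symbolic nodes and record the operation (via pow in the removed C++). A loose Python sketch of that concrete-versus-symbolic dispatch, purely illustrative and not the c10 API:

import math

class SymFloatSketch:
    # Either wraps a plain float (data) or a symbolic node exposing a .pow() method.
    def __init__(self, data=None, node=None):
        self.data, self.node = data, node

    def is_symbolic(self):
        return self.node is not None

    def sqrt(self):
        if not self.is_symbolic():
            # Concrete value: evaluate immediately.
            return SymFloatSketch(data=math.sqrt(self.data))
        # Symbolic value: build a new node recording the square root.
        return SymFloatSketch(node=self.node.pow(0.5))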
3 changes: 0 additions & 3 deletions c10/core/SymFloat.h
@@ -40,9 +40,6 @@ class C10_API SymFloat {
SymFloat operator*(const SymFloat&) const;
SymFloat operator/(const SymFloat&) const;

// Need guidance on where to put this code
SymFloat sqrt() const;

// Insert a guard for the float to be its concrete value, and then return
// that value. This operation always works, even if the float is symbolic,
// so long as we know what the underlying value is. Don't blindly put this
1 change: 0 additions & 1 deletion test/onnx/test_models_onnxruntime.py
@@ -394,7 +394,6 @@ def forward(self, images, features: Mapping[str, torch.Tensor]):
)

@skipScriptTest() # TODO: #75625
@skipIfUnsupportedMinOpsetVersion(20)
def test_transformer_encoder(self):
class MyModule(torch.nn.Module):
def __init__(self, ninp, nhead, nhid, dropout, nlayers):
21 changes: 10 additions & 11 deletions torch/nn/functional.py
@@ -5173,20 +5173,19 @@ def multi_head_attention_forward(
# (deep breath) calculate attention and out projection
#

B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
if attn_mask is not None:
if attn_mask.size(0) == 1:
attn_mask = attn_mask.unsqueeze(0)
else:
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

q = q.view(bsz, num_heads, tgt_len, head_dim)
k = k.view(bsz, num_heads, src_len, head_dim)
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
if dropout_p > 0.0:
attn_output_weights = dropout(attn_output_weights, p=dropout_p)

attn_output, attn_output_weights = _scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, need_weights, False)
attn_output = attn_output.transpose(1, 2).transpose(0, 1).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = torch.bmm(attn_output_weights, v)

attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

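This hunk reinstates the hand-rolled attention in multi_head_attention_forward instead of routing through _scaled_dot_product_attention: q is scaled by 1/sqrt(E), a float attn_mask is folded in via baddbmm, then softmax, optional dropout, and a final bmm with v. A condensed Python sketch of that restored path (shapes and names follow the diff, but this is a simplification, not the full function):

import math
import torch

def mha_core_sketch(q, k, v, attn_mask=None, dropout_p=0.0):
    # q, k, v: (bsz * num_heads, seq_len, head_dim); attn_mask: additive float mask
    # already broadcastable to (bsz * num_heads, tgt_len, src_len).
    E = q.size(-1)  # head_dim
    q_scaled = q / math.sqrt(E)
    if attn_mask is not None:
        # baddbmm adds the mask while computing q @ k^T in a single kernel.
        attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
    else:
        attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
    attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
    if dropout_p > 0.0:
        attn_output_weights = torch.dropout(attn_output_weights, dropout_p, train=True)
    attn_output = torch.bmm(attn_output_weights, v)
    return attn_output, attn_output_weights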
