Commit 910bbfe
Support 3D attention mask (#5887)
Support 3D attention mask with shape (batch_size, sequence_length, all_sequence_length)
1 parent cc6e8fb commit 910bbfe
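
For context, the existing 2D raw mask of shape (batch_size, all_sequence_length) applies one key mask to every query position; a 3D mask gives each query position its own row over the keys. A minimal illustration with hypothetical values (not taken from the commit), for batch_size = 1, sequence_length = 2 and all_sequence_length = 3:

// Hypothetical raw 3D attention mask of shape (1, 2, 3):
// entry [b][s][k] is 1 if query position s in batch b may attend to key k, 0 otherwise.
const int32_t mask_3d[1][2][3] = {
    {{1, 1, 0},    // query 0 may attend to keys 0 and 1
     {1, 1, 1}}};  // query 1 may attend to all three keys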

File tree: 11 files changed, +886 −596 lines changed
onnxruntime/contrib_ops/cpu/bert/attention.cc

Lines changed: 9 additions & 2 deletions
@@ -59,7 +59,10 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
   //   input       : (batch_size, sequence_length, hidden_size)
   //   weights     : (hidden_size, 3 * hidden_size)
   //   bias        : (3 * hidden_size)
-  //   mask_index  : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length)
+  //   mask_index  : nullptr, (batch_size), (2 * batch_size),
+  //                 or (batch_size, 1), (1, 1)
+  //                 or (batch_size, past_sequence_length + sequence_length)
+  //                 or (batch_size, sequence_length, past_sequence_length + sequence_length)
   //   past        : (2, batch_size, num_heads, past_sequence_length, head_size)
 
   const auto& dims = input_shape.GetDims();
@@ -136,8 +139,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
           return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with raw attention mask shall have shape batch_size x (past_sequence_length + sequence_length)");
         }
       }
+    } else if (mask_dims.size() == 3) {
+      if (static_cast<int>(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast<int>(mask_dims[2]) != past_sequence_length + sequence_length) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' of 3d shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
+      }
     } else {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1 or 2 dimensions, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2 or 3 dimensions, got ",
                              mask_dims.size());
     }
   }
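
As a standalone sketch of what the new branch checks (IsValid3DMask is a hypothetical helper written only for illustration; the real logic is inline in AttentionBase::CheckInputs):

#include <cstdint>
#include <vector>

// A 3D mask must match (batch_size, sequence_length, past_sequence_length + sequence_length).
bool IsValid3DMask(const std::vector<int64_t>& mask_dims,
                   int batch_size, int sequence_length, int past_sequence_length) {
  return mask_dims.size() == 3 &&
         static_cast<int>(mask_dims[0]) == batch_size &&
         static_cast<int>(mask_dims[1]) == sequence_length &&
         static_cast<int>(mask_dims[2]) == past_sequence_length + sequence_length;
}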

onnxruntime/contrib_ops/cpu/bert/attention_base.h

Lines changed: 2 additions & 2 deletions
(Whitespace-only change: the trailing member comments are realigned.)

@@ -26,8 +26,8 @@ class AttentionBase {
                      int sequence_length,
                      int& past_sequence_length) const;
 
-  int num_heads_;           // number of attention heads
-  bool is_unidirectional_;  // whether every token can only attend to previous tokens.
+  int num_heads_;               // number of attention heads
+  bool is_unidirectional_;      // whether every token can only attend to previous tokens.
 };
 
 }  // namespace contrib

onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ class AttentionCPUBase : public AttentionBase {
                   const T* K,                                   // k data. Its size is BxNxSxH
                   const int32_t* mask_index,                    // mask index. nullptr if no mask or its size is B
                   const std::vector<int64_t>* mask_index_dims,  // mask index shape
-                  T* mask_data,                                 // buffer for mask data. Its size is: SxS* if is_unidirectional_; BxSxS* if mask_index; null otherwise
+                  T* mask_data,                                 // buffer for mask data. It is nullptr if mask_index is nullptr, otherwise its shape is BxSxS*
                   int batch_size,                               // batch size of self-attention
                   int sequence_length,                          // sequence length of self-attention
                   int past_sequence_length,                     // sequence length of past state
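
In the shape shorthand used by these comments, B = batch_size, S = sequence_length and S* = past_sequence_length + sequence_length (the all_sequence_length from the commit message). A hypothetical sizing helper, written here only to make the BxSxS* comment concrete:

#include <cstddef>

// Number of elements in the mask_data scratch buffer: B x S x S*.
size_t MaskDataSize(int batch_size, int sequence_length, int past_sequence_length) {
  const int all_sequence_length = past_sequence_length + sequence_length;
  return static_cast<size_t>(batch_size) * sequence_length * all_sequence_length;
}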

onnxruntime/contrib_ops/cpu/bert/attention_helper.h

Lines changed: 20 additions & 2 deletions
@@ -72,12 +72,31 @@ void PrepareMask(const int32_t* mask_index,
   // mask_data has been filled with 0, and its shape is BxSxS*
   T* p_mask = mask_data;
 
+  // For 3D mask, convert values 0 to -10000.0, and 1 to 0.0, then apply unidirectional mask if any.
+  if (nullptr != mask_index_dims && mask_index_dims->size() == 3) {
+    for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) {
+      p_mask[i] = (mask_index[i] > 0) ? static_cast<T>(0.0f) : static_cast<T>(-10000.0f);
+    }
+
+    if (is_unidirectional) {
+      for (int b_i = 0; b_i < batch_size; b_i++) {
+        for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
+          for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
+            p_mask[s_i * all_sequence_length + m_i] += static_cast<T>(-10000.0f);
+          }
+        }
+        p_mask += sequence_length * all_sequence_length;
+      }
+    }
+
+    return;
+  }
+
   bool is_raw_attention_mask = (nullptr != mask_index_dims && mask_index_dims->size() == 2);
   bool has_mask_start_position = (nullptr != mask_index_dims && mask_index_dims->size() == 1 && static_cast<int>(mask_index_dims->at(0)) == 2 * batch_size);
 
   for (int b_i = 0; b_i < batch_size; b_i++) {
     // TODO: mask_index can be used in softmax to save some calculation.
-
     if (nullptr != mask_index) {
       if (is_raw_attention_mask) {
         // Raw attention mask has value 0 or 1. Here we convert 0 to -10000.0, and 1 to 0.0.
@@ -120,7 +139,6 @@ void PrepareMask(const int32_t* mask_index,
 
     p_mask += sequence_length * all_sequence_length;
   }
-
 }
 
 // Concatenate a past state chunk S'xH with input state chunk SxH into present state chunk S*xH
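
To see the new control flow in isolation, here is a self-contained sketch of just the 3D path, specialized to float (Prepare3DMask is a hypothetical name; in the commit this logic is the templated branch added to PrepareMask above):

#include <cstdint>

// Convert raw mask values: 0 -> -10000 (blocked), 1 -> 0 (visible). For
// unidirectional attention, additionally block every key beyond position
// past_sequence_length + s_i for the query at position s_i.
void Prepare3DMask(const int32_t* mask_index, float* mask_data,
                   int batch_size, int sequence_length,
                   int past_sequence_length, int all_sequence_length,
                   bool is_unidirectional) {
  const int total = batch_size * sequence_length * all_sequence_length;
  for (int i = 0; i < total; i++) {
    mask_data[i] = (mask_index[i] > 0) ? 0.0f : -10000.0f;
  }

  if (is_unidirectional) {
    float* p_mask = mask_data;
    for (int b_i = 0; b_i < batch_size; b_i++) {
      // The last query row (s_i == sequence_length - 1) already sees all keys,
      // so the loop stops one row early, matching the original code.
      for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
        for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
          p_mask[s_i * all_sequence_length + m_i] += -10000.0f;
        }
      }
      p_mask += sequence_length * all_sequence_length;
    }
  }
}

Note that the unidirectional pass adds -10000 instead of overwriting: positions the raw mask already blocked stay at a large negative value either way, so both cases vanish after softmax.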
