Add Flash GPT2 #93

Merged
merged 9 commits on Dec 5, 2023
Forward
tgaddair committed Dec 1, 2023
commit 17505370b488354317543a81c61b12dc69040348
53 changes: 53 additions & 0 deletions server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -153,6 +153,59 @@ def forward(
        input_lengths,
        max_s,
    ):
        # Prepare query, key, and value
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # Apply Flash Attention forward
        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            # Flash Attention forward pass; it does not return attention weights
            attn_weights = None
            attn_output = self._flash_attention_forward(
                query, key, value, attention_mask, query.size(-2), self.attn_dropout.p, softmax_scale=None
            )

        # Merge heads and project back to hidden size
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)

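For context, here is a minimal sketch of how a flash-style `qkv` tensor like the one above (viewed as `(-1, 3, num_heads, head_size)`) can be fed to a varlen flash-attention kernel. This is not the PR's `_flash_attention_forward` implementation; it assumes the public flash-attn 2.x `flash_attn_varlen_func` API, and the helper name and the `input_lengths`-to-`cu_seqlens` conversion are illustrative only.

# Sketch (assumption): packed, unpadded causal self-attention using the
# public flash-attn 2.x API. Tensor names mirror the diff; everything else
# is illustrative and not part of this PR.
import torch
from flash_attn import flash_attn_varlen_func


def flash_attention_sketch(qkv: torch.Tensor, input_lengths: torch.Tensor, max_s: int, softmax_scale=None):
    # qkv: (total_tokens, 3, num_heads, head_size), as produced by
    # `qkv = qkv.view(-1, 3, self.num_heads, self.head_size)` above.
    query, key, value = qkv.unbind(dim=1)

    # flash_attn_varlen_func expects cumulative sequence lengths (int32),
    # so turn per-sequence lengths into a prefix sum starting at 0.
    cu_seqlens = torch.nn.functional.pad(
        input_lengths.cumsum(0, dtype=torch.int32), (1, 0)
    )

    # Causal self-attention over the packed (unpadded) token stream.
    return flash_attn_varlen_func(
        query,
        key,
        value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_s,
        max_seqlen_k=max_s,
        softmax_scale=softmax_scale,
        causal=True,
    )

With `softmax_scale=None` the kernel defaults to 1/sqrt(head_size), matching the usual GPT-2 scaling; the `input_lengths`/`max_s` arguments correspond to the parameters visible in the `forward` signature of the hunk.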