
Commit

fix
vchiley committed Mar 20, 2023
1 parent 4c2c895 commit 9b77639
Showing 3 changed files with 5 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/llm/src/models/layers/attention.py
@@ -302,6 +302,7 @@ def forward(self,

         query, key, value = qkv.chunk(3, dim=2)

+        query_padding_mask = None
         if key_padding_mask is not None:
             query_padding_mask = key_padding_mask[:, -query.size(1):]

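The attention.py change guards against query_padding_mask being used while undefined: it is now initialized to None before the conditional, and only replaced by the tail slice of key_padding_mask when a key padding mask is supplied (with a KV cache the query covers only the last query.size(1) positions of the full sequence). A minimal runnable sketch of that slicing logic, separate from the diff and using a hypothetical helper name, is:

import torch

# Sketch only, assuming the usual KV-cache layout where the query covers the
# last seq_len_q positions of the full (cached + new) sequence.
# slice_query_padding_mask is a hypothetical helper, not code from the repo.
def slice_query_padding_mask(key_padding_mask, seq_len_q):
    query_padding_mask = None  # always defined, mirroring the one-line fix
    if key_padding_mask is not None:
        query_padding_mask = key_padding_mask[:, -seq_len_q:]
    return query_padding_mask

kpm = torch.ones(2, 10, dtype=torch.bool)       # batch of 2, 10 key positions
print(slice_query_padding_mask(kpm, 2).shape)   # torch.Size([2, 2])
print(slice_query_padding_mask(None, 2))        # None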
4 changes: 2 additions & 2 deletions examples/llm/src/models/layers/gpt_blocks.py
@@ -3,7 +3,7 @@

 """GPT Blocks used for the GPT Model."""

-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -52,7 +52,7 @@ def forward(
         attn_bias: Optional[torch.Tensor] = None,
         key_padding_mask: Optional[torch.ByteTensor] = None,
         is_causal: bool = True,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, Union[Tuple[torch.Tensor], None]]:
         a = self.ln_1(x)
         b, _, past_key_value = self.attn(a,
                                          past_key_value=past_key_value,
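The gpt_blocks.py change only touches annotations: the block's forward already returns both the hidden states and the attention layer's key/value cache, and the return type now says so (Union is imported for the optional second element). A simplified, hypothetical block illustrating the annotated shape (stand-in layers, not the repository's GPTBlock):

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

# Simplified stand-in for a GPT block; needs Python 3.9+ for tuple[...]
# subscription in the annotation.
class TinyBlock(nn.Module):

    def __init__(self, d_model: int):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        self.proj = nn.Linear(d_model, d_model)  # placeholder for self.attn

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> tuple[torch.Tensor, Union[Tuple[torch.Tensor], None]]:
        a = self.ln_1(x)
        b = self.proj(a)              # the real block calls self.attn(...) here
        return x + b, past_key_value  # two values, matching the new annotation

out, cache = TinyBlock(8)(torch.randn(2, 4, 8))
print(out.shape, cache)               # torch.Size([2, 4, 8]) None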
4 changes: 2 additions & 2 deletions examples/llm/src/models/mosaic_gpt.py
@@ -176,15 +176,15 @@ def forward(
         attn_bias = self._attn_bias(device=x.device, dtype=x.dtype)

         for b_idx, block in enumerate(self.transformer.blocks):  # type: ignore
-            past_key_value = past_key_value[
+            past_key_value = past_key_values[
                 b_idx] if past_key_values is not None else None
             x, past_key_value = block(x,
                                       past_key_value=past_key_value,
                                       attn_bias=attn_bias,
                                       key_padding_mask=key_padding_mask,
                                       is_causal=self.is_causal)
             if past_key_values is not None:
-                past_key_value[b_idx] = past_key_value
+                past_key_values[b_idx] = past_key_value

         x = self.transformer.ln_f(x)  # type: ignore
         # output embedding weight tied to input embedding
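The mosaic_gpt.py change fixes a singular/plural mix-up: inside the block loop, each block's cache must be read from and written back to the list past_key_values, not indexed off the single per-block past_key_value. A toy sketch of that read-run-write pattern (run_blocks and ToyBlock are illustrative names, not the repository's code):

from typing import List, Optional

import torch

def run_blocks(blocks, x: torch.Tensor,
               past_key_values: Optional[List] = None) -> torch.Tensor:
    for b_idx, block in enumerate(blocks):
        # read this block's cache from the list (or None when not caching)
        past_key_value = past_key_values[
            b_idx] if past_key_values is not None else None
        x, past_key_value = block(x, past_key_value=past_key_value)
        if past_key_values is not None:
            past_key_values[b_idx] = past_key_value  # write back into the list
    return x

class ToyBlock:
    """Identity block whose 'cache' is just the last input it saw."""
    def __call__(self, x, past_key_value=None):
        return x, (x,)

pkvs = [None, None]
out = run_blocks([ToyBlock(), ToyBlock()], torch.randn(2, 3), pkvs)
print(out.shape, [p[0].shape for p in pkvs])  # caches were updated in place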
