Single location to update optional args for all attentions #8128

Merged (1 commit, Feb 1, 2025)

26 changes: 14 additions & 12 deletions examples/models/llama/attention.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple, Type
from typing import Any, Dict, Optional, Tuple, Type, TypedDict

import torch
import torch.nn as nn
@@ -8,6 +8,15 @@
from executorch.examples.models.llama.rope import Rope


class ForwardOptions(TypedDict, total=False):
"""Optional parameters for `Attention.forward` (compative with Python 3.10 and plus)."""

mask: Optional[torch.Tensor]
input_pos: Optional[torch.Tensor]
in_cache_state: Optional[Any]
out_cache_state: Optional[Any]


class Attention(nn.Module, ABC):
"""Abstract base class for attention mechanisms with unified interface."""

@@ -17,19 +26,14 @@ def forward(
x: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
in_cache_state: Optional[Any] = None,
out_cache_state: Optional[Any] = None,
**kwargs: ForwardOptions,
) -> Tuple[torch.Tensor, Optional[Any]]:
"""Forward pass for attention mechanism.

Args:
x: Input tensor of shape (batch_size, seq_len, dim)
freqs_cos, freqs_sin: Rotary position embedding frequencies
mask: Optional attention mask
input_pos: Positions for KV cache updates
in_cache_state/out_cache_state: Cache states
**kwargs: Optional args grouped in ForwardOptions (mask, input_pos, in_cache_state, out_cache_state)

Returns:
Tuple of (output tensor, updated cache state)
@@ -209,11 +213,9 @@ def forward(
x: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
in_cache_state: Optional[Any] = None,
out_cache_state: Optional[Any] = None,
**kwargs: ForwardOptions,
) -> Tuple[torch.Tensor, Optional[Any]]:
input_pos = kwargs.get("input_pos")
bsz, seqlen, _ = x.shape

# QKV
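
For reference, here is a minimal, self-contained sketch of the pattern this PR introduces: the optional attention arguments live in a single ForwardOptions TypedDict, and each forward implementation reads only the keys it needs via kwargs.get. DummyAttention and the toy call at the bottom are hypothetical and exist only to illustrate usage; they are not part of the change.

from typing import Any, Optional, Tuple, TypedDict

import torch
import torch.nn as nn


class ForwardOptions(TypedDict, total=False):
    """Optional parameters for attention forward passes (mirrors the PR)."""

    mask: Optional[torch.Tensor]
    input_pos: Optional[torch.Tensor]
    in_cache_state: Optional[Any]
    out_cache_state: Optional[Any]


class DummyAttention(nn.Module):
    """Hypothetical attention stub; a real implementation does the QKV math here."""

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
        **kwargs: ForwardOptions,
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # Each implementation pulls only the options it needs; adding a new
        # optional arg later only requires extending ForwardOptions.
        input_pos = kwargs.get("input_pos")
        mask = kwargs.get("mask")
        # Identity "attention" keeps the sketch runnable without real QKV math.
        return x, kwargs.get("out_cache_state")


attn = DummyAttention()
x = torch.randn(1, 4, 8)
# Callers keep passing the options by keyword, exactly as before the refactor:
out, cache = attn(x, torch.empty(0), torch.empty(0), input_pos=torch.tensor([0]))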