Pass ForwardOptions from top level module and also return any relevant state as output #8186

Merged
merged 1 commit on Feb 6, 2025

4 changes: 2 additions & 2 deletions examples/models/llama/attention.py
@@ -236,7 +236,7 @@ def forward(
assert input_pos is not None
k, v = self.kv_cache.update(input_pos, k, v)
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
return self.wo(output)
return self.wo(output), None
Contributor

Is `None` a placeholder or something else?

Contributor Author

It's to support additional output from each attention layer that will be aggregated and returned by the top-level transformer. To avoid disruptions, if this is `None` the top-level transformer still returns a single logits output; otherwise it returns a tuple.

This was actually added in the previous PR that introduced the abstract attention class:

) -> Tuple[torch.Tensor, Optional[Any]]:
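
To make the contract concrete, here is a minimal sketch of an attention implementation that uses the second return value. The `ToyStatefulAttention` class and the `"toy_state"` key are hypothetical, for illustration only; they are not part of this PR.

    from typing import Any, Optional, Tuple

    import torch
    import torch.nn as nn


    class ToyStatefulAttention(nn.Module):
        """Hypothetical attention stub that also reports per-layer state."""

        def __init__(self, dim: int):
            super().__init__()
            self.wo = nn.Linear(dim, dim, bias=False)

        def forward(
            self, x: torch.Tensor, freqs_cos, freqs_sin, **kwargs: Any
        ) -> Tuple[torch.Tensor, Optional[Any]]:
            # A real implementation would do attention math here; a projection stands in.
            out = self.wo(x)
            # Returning a non-None second element is what makes the top-level
            # Transformer return (logits, update) instead of a bare logits tensor.
            return out, {"toy_state": out.detach().mean()}


    attn = ToyStatefulAttention(dim=8)
    out, update = attn(torch.randn(1, 3, 8), None, None, input_pos=torch.tensor([0]))
    print(update)  # e.g. {'toy_state': tensor(0.0123)}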


# grouped multiquery attention: expand out keys and values
k = k.repeat_interleave(self.n_rep, dim=1)
@@ -252,4 +252,4 @@ def forward(

output = self.wo(output)

return output
return output, None
8 changes: 6 additions & 2 deletions examples/models/llama/eval_llama_lib.py
@@ -65,7 +65,9 @@ def _model_call(self, inps):
result_logits = []
for pos in range(inps.shape[-1]):
pos_tensor = torch.tensor([pos], dtype=torch.int64)
logits = self._model(inps[:, pos : pos + 1], pos_tensor)
logits = self._model(
inps[:, pos : pos + 1], {"input_pos": pos_tensor}
)
result_logits.append(logits)
if self._generate_full_logits:
return torch.cat(result_logits, dim=1)
@@ -74,7 +76,9 @@ def _model_call(self, inps):
else:
pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
# Batch process the whole sequence.
logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
logits = self._model(
inps[:, : self._max_seq_length], {"input_pos": pos_tensor}
)
return logits

else:
4 changes: 3 additions & 1 deletion examples/models/llama/evaluate/eager_eval.py
@@ -77,7 +77,9 @@ def _model_call(self, inps):
if self._use_kv_cache:
pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device)
# Batch process the whole sequence.
logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
logits = self._model(
inps[:, : self._max_seq_length], {"input_pos": pos_tensor}
)
return logits
else:
return self._model(inps)
41 changes: 24 additions & 17 deletions examples/models/llama/llama_transformer.py
@@ -7,12 +7,15 @@

# Please refer to README.md in the same folder for more information.

from typing import Optional
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn.functional as F

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.attention import (
ATTENTION_REGISTRY,
ForwardOptions,
)

from executorch.examples.models.llama.model_args import ModelArgs

@@ -148,17 +151,17 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

def forward(self, x, freqs_cos, freqs_sin, input_pos=None): # x: 1xN
h = self.attention.forward(
self.attention_norm(x), freqs_cos, freqs_sin, input_pos=input_pos
def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN
h, attn_options_update = self.attention.forward(
self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
)

h = x + h
if hasattr(self, "block_sparse_moe"):
out = h + self.block_sparse_moe(self.ffn_norm(h))
else:
out = h + self.feed_forward(self.ffn_norm(h))
return out
return out, attn_options_update


class Transformer(nn.Module):
@@ -185,27 +188,28 @@ def __init__(self, params: ModelArgs):
def forward(
self,
tokens: Optional[torch.LongTensor] = None, # tokens
input_pos: Optional[
torch.LongTensor
] = None, # Scalar tensor indicating size of window of the caches
attn_options: Optional[ForwardOptions] = None,
h: Optional[torch.FloatTensor] = None, # embeddings
) -> torch.Tensor:
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[Any]]]:
if (tokens is None) ^ (h is not None):
raise ValueError(
"You cannot specify both tokens and h at the same time, and must specify either one"
)
if tokens is not None and h is None:
h = self.tok_embeddings(tokens)

if attn_options is None:
attn_options = {}
seqlen = h.shape[1]
freqs_cos, freqs_sin = self.rope.get_freqs(input_pos, seqlen)
freqs_cos, freqs_sin = self.rope.get_freqs(
attn_options.get("input_pos"), seqlen
)

attn_options_update = None
for layer in self.layers:
h = layer(
h,
freqs_cos,
freqs_sin,
input_pos,
)
h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options)
if attn_options_update is not None:
attn_options.update(**attn_options_update)

if not self.generate_full_logits:
# Only the last logit is used for the new generated token
@@ -237,4 +241,7 @@ def forward(
expanded_logits[:, list(self.output_prune_map.values())] = logits
logits = expanded_logits

if attn_options_update is not None:
return logits, attn_options_update

return logits
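
As a usage note, the sketch below shows one way a caller can normalize the new single-tensor-or-tuple return; the `run_step` helper and the stub model are assumptions for illustration, not part of this PR.

    from typing import Any, Optional, Tuple, Union

    import torch


    def run_step(
        model, tokens: torch.Tensor, input_pos: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[Any]]:
        # Call with the ForwardOptions dict; "input_pos" is the key the default
        # attention consumes.
        out: Union[torch.Tensor, Tuple[torch.Tensor, Optional[Any]]] = model(
            tokens, {"input_pos": input_pos}
        )
        # The transformer returns a bare logits tensor when no layer produced an
        # update, and a (logits, attn_options_update) tuple otherwise.
        if isinstance(out, tuple):
            return out
        return out, None


    # Stub standing in for llama_transformer.Transformer:
    stub = lambda tokens, attn_options: torch.zeros(tokens.shape[0], 1, 32)
    logits, update = run_step(stub, torch.tensor([[1]]), torch.tensor([0]))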
10 changes: 6 additions & 4 deletions examples/models/llama/model.py
@@ -289,16 +289,18 @@ def get_example_inputs_kvcache_sdpa(self):
if self.enable_dynamic_shape:
return (
torch.tensor([[2, 3, 4]], dtype=torch.long),
torch.tensor([0], dtype=torch.long),
{"input_pos": torch.tensor([0], dtype=torch.long)},
)
else:
return (
torch.tensor(
[[1]], dtype=torch.long
), # tokens, with kv cache our input token length is always just 1 token.
torch.tensor(
[0], dtype=torch.long
), # start_pos, what token of output are we on.
{
"input_pos": torch.tensor(
[0], dtype=torch.long
) # start_pos, what token of output are we on.
},
)

def _transform_for_pre_quantization(self, checkpoint, model_args):

Contributor

Is it because we are plugging in the ForwardOptions dict down the line?

Contributor Author

Yeah, "input_pos" is the only thing the regular attention needs; other implementations can take additional/different parameters.
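
To illustrate that last point, here is a hypothetical example-inputs tuple for an attention variant that reads one extra key. The `"cache_positions"` key is invented for this sketch; the default attention only consumes `"input_pos"`.

    import torch

    example_inputs = (
        torch.tensor([[1]], dtype=torch.long),  # tokens
        {
            # Consumed by the default attention.
            "input_pos": torch.tensor([0], dtype=torch.long),
            # Hypothetical extra option a custom attention could accept via **kwargs.
            "cache_positions": torch.tensor([0], dtype=torch.long),
        },
    )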
2 changes: 1 addition & 1 deletion examples/models/llama/runner/eager.py
@@ -42,7 +42,7 @@ def forward(
tokens: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model.forward(tokens=tokens, input_pos=input_pos)
return self.model.forward(tokens, {"input_pos": input_pos})


def build_args_parser() -> argparse.ArgumentParser:
2 changes: 1 addition & 1 deletion examples/models/llava/export_llava.py
@@ -80,7 +80,7 @@ def __init__(self, llava):
self.text_model = llava.text_model

def forward(self, input_pos, embeddings):
return self.text_model(None, input_pos, embeddings)
return self.text_model(None, {"input_pos": input_pos}, embeddings)

llava_text_model = LlavaTextModel(llava)

6 changes: 4 additions & 2 deletions examples/models/llava/model.py
@@ -208,7 +208,7 @@ def step(
) -> torch.Tensor:
"""Input is one token. Return logits for next token."""
token_embeds = self.embed_tokens(token).unsqueeze(0)
return self.text_model.forward(None, input_pos, token_embeds)
return self.text_model.forward(None, {"input_pos": input_pos}, token_embeds)

def image_embedding(self, images: torch.Tensor) -> torch.Tensor:
preprocessed_img = self.image_preprocess(images)
@@ -236,7 +236,9 @@ def prefill(
"""Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
# returns the prefilled token length too, because the text model generates one logits in each forward call.
return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)
return embeds.shape[1], self.text_model.forward(
None, {"input_pos": torch.tensor([0])}, embeds
)

# reference prefill using the text model in HF
def prefill_ref(
4 changes: 2 additions & 2 deletions extension/llm/export/builder.py
@@ -170,7 +170,7 @@ def _get_dynamic_shape(self) -> Any:
self.dynamic_shapes = ({1: dim},)
elif self.enable_dynamic_shape:
# Two input arguments: tokens and input_pos but input_pos is static shape
self.dynamic_shapes = ({1: dim}, {0: 1})
self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
else:
# Two input arguments: tokens and input_pos but both are of static shape
self.dynamic_shapes = None
@@ -270,7 +270,7 @@ def calibrate_template(
while token_list[-1] != tokenizer.eos_id and pos < max_len:
logits = module(
torch.full((1, 1), token_list[pos]),
torch.tensor((pos,)),
{"input_pos": torch.tensor((pos,))},
)
pos += 1
if pos >= len(token_list):