Factor out eager val from eval_llama_lib #3756

Closed
wants to merge 5 commits
9 changes: 8 additions & 1 deletion examples/models/llama2/builder.py
@@ -15,6 +15,14 @@
from typing import Any, Callable, List, Optional

import torch

try:
    from ...portable.utils import export_to_edge, save_pte_program
except ImportError:
Contributor: Let's put this as a task to fix separately. I've seen this issue in other parts of the repo too.

Contributor (PR author): Yeah, makes sense. Let me update.

    # Workaround to bypass the different paths between the executorch pip package and a direct python call
    # TODO: remove this try/except workaround and have a standard way to import portable.utils
    # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `examples.portable.utils`.
    from examples.portable.utils import export_to_edge, save_pte_program
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
@@ -33,7 +41,6 @@
from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
from torch.nn.attention import SDPBackend

from ...portable.utils import export_to_edge, save_pte_program
from ..model_factory import EagerModelFactory

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
112 changes: 2 additions & 110 deletions examples/models/llama2/eval_llama_lib.py
@@ -9,8 +9,8 @@

from typing import Optional, Union

import lm_eval
import torch
from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
from executorch.examples.models.llama2.export_llama_lib import (
    get_quantizer_and_quant_params,
)
@@ -20,11 +20,6 @@
)

from lm_eval.api.model import LM
from lm_eval.evaluator import evaluate
from lm_eval.models.huggingface import HFLM as eval_wrapper
from lm_eval.tasks import get_task_dict

from torch import nn

from .builder import LlamaEdgeManager
from .export_llama_lib import (
@@ -33,75 +28,6 @@
)


class EagerEvalWrapper(eval_wrapper):
    """
    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
    ):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(device=device)
        self._model = model
        self._tokenizer = tokenizer
        self._device = torch.device(device)
        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
        self._use_kv_cache = use_kv_cache

    @property
    def eot_token_id(self):
        return self._tokenizer.eos_id

    @property
    def max_length(self):
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        return 50

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str, **kwargs):
        tokens = self._tokenizer.encode(string, bos=True, eos=False)
        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
        # encoded is a pytorch tensor, but some internal logic in the
        # eval harness expects it to be a list instead
        # TODO: verify this for multi-batch as well
        encoded = encoded.tolist()
        return encoded

    def tok_decode(self, tokens):
        decoded = self._tokenizer.decode(tokens)
        return decoded

    def _model_call(self, inps):
        if self._use_kv_cache:
            pos_tensor = torch.arange(
                self._max_seq_length, dtype=torch.int64, device=self.device
            )

            # Batch process the whole sequence.
            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
            return logits
        else:
            return self._model(inps)

    def _model_generate(self, context, max_length, eos_token_id):
        raise Exception("unimplemented")


class ETPybindEvalWrapper(EagerEvalWrapper):
    """
    A wrapper class for ExecuTorch py-binded integration with the
@@ -165,40 +91,6 @@ def _model_call(self, inps):
        pass


@torch.no_grad()
def eval(
    eval_wrapper: LM,
    tasks: Optional[list] = None,
    limit: Optional[int] = None,
) -> dict:
    """
    Evaluates a language model on a specified task using the lm-evaluation-harness library.

    Args:
        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
        task (str): The name of the evaluation task to perform.
        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).

    Returns:
        eval_results (dict): A dictionary of evaluation results for the specified task(s).
    """

    if tasks is None:
        tasks = ["wikitext"]

    if "hendrycks_test" in tasks:
        tasks.remove("hendrycks_test")
        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
    task_dict = get_task_dict(tasks)

    eval_results = evaluate(
        eval_wrapper,
        task_dict,
        limit=limit,
    )
    return eval_results


def gen_eval_wrapper(
    model_name: str,
    args: argparse.ArgumentParser,
@@ -307,7 +199,7 @@ def eval_llama(
    eval_wrapper = gen_eval_wrapper(model_name, args)

    # Evaluate the model
    eval_results = eval(
    eval_results = evaluate_model(
        eval_wrapper,
        args.tasks,
        args.limit,
12 changes: 12 additions & 0 deletions examples/models/llama2/evaluate/__init__.py
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .eager_eval import EagerEvalWrapper, evaluate_model

__all__ = [
    "evaluate_model",
    "EagerEvalWrapper",
]
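
For reference, a minimal sketch of the import surface the new evaluate package exposes; it matches the import that eval_llama_lib.py switches to in this PR, while the consuming script itself is hypothetical:

```python
# Hypothetical downstream consumer: the eager-eval pieces now live in the
# evaluate package instead of being defined inside eval_llama_lib.py.
from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
```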
125 changes: 125 additions & 0 deletions examples/models/llama2/evaluate/eager_eval.py
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Union

import lm_eval
import torch
from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
from executorch.examples.models.llama2.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)

from lm_eval.api.model import LM
from lm_eval.evaluator import evaluate
from lm_eval.models.huggingface import HFLM as eval_wrapper
from lm_eval.tasks import get_task_dict

from torch import nn


class EagerEvalWrapper(eval_wrapper):
    """
    A wrapper class based on GPTFast, providing integration with the lm-evaluation-harness library.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Union[SentencePieceTokenizer, Tiktoken],
        max_seq_length: Optional[int] = None,
        use_kv_cache: bool = False,
    ):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        super().__init__(device=device)
        self._model = model
        self._tokenizer = tokenizer
        self._device = torch.device(device)
        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
        self._use_kv_cache = use_kv_cache

    @property
    def eot_token_id(self):
        return self._tokenizer.eos_id

    @property
    def max_length(self):
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        return 50

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str, **kwargs):
        tokens = self._tokenizer.encode(string, bos=True, eos=False)
        encoded = torch.tensor(tokens, dtype=torch.int, device=self.device)
        # encoded is a pytorch tensor, but some internal logic in the
        # eval harness expects it to be a list instead
        # TODO: verify this for multi-batch as well
        encoded = encoded.tolist()
        return encoded

    def tok_decode(self, tokens):
        decoded = self._tokenizer.decode(tokens)
        return decoded

    def _model_call(self, inps):
        if self._use_kv_cache:
            pos_tensor = torch.arange(
                self._max_seq_length, dtype=torch.int64, device=self.device
            )

            # Batch process the whole sequence.
            logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
            return logits
        else:
            return self._model(inps)

    def _model_generate(self, context, max_length, eos_token_id):
        raise Exception("unimplemented")


@torch.no_grad()
def evaluate_model(
    eval_wrapper: LM,
    tasks: Optional[list] = None,
    limit: Optional[int] = None,
) -> dict:
    """
    Evaluates a language model on the specified tasks using the lm-evaluation-harness library.

    Args:
        eval_wrapper (LM): A LM wrapper class compatible with lm-evaluation-harness evaluation
        tasks (Optional[list]): The names of the evaluation tasks to perform (defaults to ["wikitext"]).
        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).

    Returns:
        eval_results (dict): A dictionary of evaluation results for the specified task(s).
    """

    if tasks is None:
        tasks = ["wikitext"]

    if "hendrycks_test" in tasks:
        tasks.remove("hendrycks_test")
        tasks += list(lm_eval.tasks.hendrycks_test.create_all_tasks().keys())
    task_dict = get_task_dict(tasks)

    eval_results = evaluate(
        eval_wrapper,
        task_dict,
        limit=limit,
    )
    return eval_results
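
For illustration, a minimal sketch of driving the factored-out API end to end; it assumes an eager Llama nn.Module and a tokenizer instance are already constructed elsewhere, and the run_wikitext_eval helper below is hypothetical, not part of this PR:

```python
from typing import Union

import torch

from executorch.examples.models.llama2.evaluate import EagerEvalWrapper, evaluate_model
from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
from executorch.examples.models.llama2.tokenizer.tokenizer import (
    Tokenizer as SentencePieceTokenizer,
)


def run_wikitext_eval(
    model: torch.nn.Module,
    tokenizer: Union[SentencePieceTokenizer, Tiktoken],
) -> dict:
    # Wrap the eager model so lm-evaluation-harness can drive it.
    wrapper = EagerEvalWrapper(
        model=model,
        tokenizer=tokenizer,
        max_seq_length=2048,
        use_kv_cache=False,
    )
    # Cap the number of samples so this runs as a quick smoke test.
    return evaluate_model(wrapper, tasks=["wikitext"], limit=100)
```

This mirrors how eval_llama in eval_llama_lib.py now delegates to evaluate_model instead of its former local eval helper.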